CSES - Suffiksitaulukko Pythonilla

Huom! Käytä PyPyä, CPython on liian hidas.

Suffiksitaulukko

# Kokoaa suffiksitaulukon (SA) kaksinkertaistusmenetelmällä.
# s tulee päättyä merkkiin '$'.
def laske_sa(s):
    assert s[-1] == "$"
    n = len(s)

    # Alustava järjestys lasketaan yksittäisten merkkien perusteella
    merkit = sorted(set(s))
    järj_luku = [merkit.index(c) for c in s]

    pituus = 1
    while pituus < n:
        parit = []
        for i in range(n):
            if i + pituus < n:
                parit.append((järj_luku[i], järj_luku[i + pituus]))
            else:
                parit.append((järj_luku[i], 0))
        järj_luku = parien_järjestys(parit)
        pituus *= 2

    sa = [0] * n
    for i in range(n):
        sa[järj_luku[i]] = i
    return sa

# Laskujärjestäminen ensin parin toisen alkion, sitten ensimmäisen alkion
# mukaan. Palauttaa jokaiselle listan kohdalle järjestysluvun. Samoille
# pareille annetaan sama järjestysluku.
def parien_järjestys(parit):
    n = len(parit)

    määrät_0 = [0] * (n + 1)
    määrät_1 = [0] * (n + 1)
    for i in range(n):
        määrät_0[parit[i][0]+1] += 1
        määrät_1[parit[i][1]+1] += 1
    for i in range(1, n+1):
        määrät_0[i] += määrät_0[i-1]
        määrät_1[i] += määrät_1[i-1]

    järjestys_1 = [0] * n
    for i in range(n):
        kohta = määrät_1[parit[i][1]]
        järjestys_1[kohta] = i
        määrät_1[parit[i][1]] += 1

    järjestys_0 = [0] * n
    for i in järjestys_1:
        kohta = määrät_0[parit[i][0]]
        järjestys_0[kohta] = i
        määrät_0[parit[i][0]] += 1

    järj_luku = [0] * n
    edellinen = None
    laskuri = -1
    for i in järjestys_0:
        if parit[i] != edellinen: laskuri += 1
        järj_luku[i] = laskuri
        edellinen = parit[i]
    return järj_luku

LCP-taulukko

# Laskee LCP-taulukon merkkijonon ja sen suffiksitaulukon perusteella.
# s tulee päättyä merkkiin '$'.
def laske_lcp(s, sa):
    assert s[-1] == "$"
    n = len(s)

    järj_luku = [0] * n
    edeltävä = [0] * n
    for i in range(n):
        järj_luku[sa[i]] = i
        if i > 0: edeltävä[sa[i]] = sa[i-1]

    lcp = [0] * n
    pituus = 0
    for i in range(n-1):
        if pituus > 0: pituus -= 1
        while s[i + pituus] == s[edeltävä[i] + pituus]:
            pituus += 1
        lcp[järj_luku[i]] = pituus
    return lcp

Välin binäärihaku

# Palauttaa välin, jolle annettu toinen merkkijono sijoittuu
# suffiksitaulukossa. Pari (lo, hi) tarkoittaa puoliavointa väliä [lo, hi).
def sa_väli(s, sa, p):
    lo = 0
    hi = len(sa)
    for i, merkki in enumerate(p):
        lo = binäärihaku(lo, hi, lambda sa_i: s[sa[sa_i] + i] >= merkki)
        hi = binäärihaku(lo, hi, lambda sa_i: s[sa[sa_i] + i] > merkki)
    return (lo, hi)

# Etsii ensimmäisen kohdan, jossa ehto toteutuu.
def binäärihaku(lo, hi, ehto):
    while lo < hi:
        mid = (lo + hi) // 2
        if ehto(mid): hi = mid
        else: lo = mid + 1
    return lo