Cod sursa(job #3312468)

Utilizator vlad.ulmeanu30 Ulmeanu Vlad vlad.ulmeanu30 Data 28 septembrie 2025 15:24:22
Problema Aho-Corasick Scor 25
Compilator cpp-64 Status done
Runda Arhiva educationala Marime 10.36 kb
#include <bits/stdc++.h>
///Debugging helpers. None of the dbg* macros below is invoked in this file.
///aaa: pauses execution by running `read -r -p ...` through system(), i.e. waits until Enter is pressed.
#define aaa system("read -r -p \"Press enter to continue...\" key");
///dbg(x): prints "x: <value>" to stderr, then pauses.
#define dbg(x) std::cerr<<(#x)<<": "<<(x)<<'\n',aaa
///dbga(x,n): prints the first n elements of the indexable x on one line, then pauses.
///NOTE: dbga, dbgs and dbgsp expand to several statements; do not use them as the body of an unbraced if/for/while.
#define dbga(x,n) std::cerr<<(#x)<<"[]: ";for(int _=0;_<n;_++)std::cerr<<x[_]<<' ';std::cerr<<'\n',aaa
///dbgs(x): prints every element of the range x on one line, then pauses.
#define dbgs(x) std::cerr<<(#x)<<"[stl]: ";for(auto _:x)std::cerr<<_<<' ';std::cerr<<'\n',aaa
///dbgp(x): prints both members of the pair x, then pauses.
#define dbgp(x) std::cerr<<(#x)<<": "<<x.fi<<' '<<x.se<<'\n',aaa
///dbgsp(x): prints every pair of the range x on its own line, then pauses.
#define dbgsp(x) std::cerr<<(#x)<<"[stl pair]:\n";for(auto _:x)std::cerr<<_.fi<<' '<<_.se<<'\n';aaa
///Shorthands for std::pair members (used below, e.g. sh_loc.fi / sh_loc.se).
#define fi first
#define se second

///Low-29-bit mask (2^29 - 1); mul() uses it to split a 62-bit cross term at bit 29.
constexpr uint32_t ct229 = 0x1FFFFFFFu;
///Hash modulus: the Mersenne prime 2^61 - 1, and twice its value (for a single-step reduction in mul()).
constexpr uint64_t M61 = 0x1FFFFFFFFFFFFFFFULL;
constexpr uint64_t M61_2x = 2 * M61;

///Describes one window of s whose length is a power of two (2**z), together with its "shade":
///the characters of s that immediately follow the window.
struct PrefixInfo {
    std::array<uint64_t, 2> hh_ps; ///[0] = hash of the window; [1] = hash of (a prefix of) the shade.
    std::pair<int, int> sh_loc; ///inclusive [first, last] positions in s of the window's shade (empty when first > last).
    int lev; ///multiplicity: several windows can share the same window hash and shade; they are merged into one entry (DAG-style compression) and counted here.
};

///Hashed form of one query word t, split at the largest power of two 2**z not exceeding |t|.
struct TsInfo {
    std::array<uint64_t, 2> hh_ps; ///[0] = hash of the prefix (the first 2**z characters); [1] = hash of the suffix (the rest).
    int suff_len; ///length of the suffix, i.e. |t| - 2**z.
    int ind; ///index of the word in ts (and in ts_count).
};

///Returns (a * b) mod M61 (M61 = 2^61 - 1), fully reduced into [0, M61).
///Assumes a, b < 2^61 so that every partial product below fits in 64 bits.
inline uint64_t mul(uint64_t a, uint64_t b) {
    ///split each operand into a high part (< 2^29) and a low 32-bit part.
    const uint64_t ah = a >> 32, al = a & 0xFFFFFFFFULL;
    const uint64_t bh = b >> 32, bl = b & 0xFFFFFFFFULL;

    ///cross terms carry weight 2^32. With cross = q * 2^29 + r:
    ///cross * 2^32 = q * 2^61 + r * 2^32, which is congruent to q + r * 2^32 (mod M61).
    const uint64_t cross = ah * bl + al * bh;
    uint64_t mid = ((cross & ct229) << 32) + (cross >> 29);
    ///the high term carries weight 2^64 = 2^3 * 2^61, which is congruent to 2^3 (mod M61).
    mid += (ah * bh) << 3;

    const uint64_t low = al * bl;

    ///fold both partial sums using x = (x >> 61) + (x & M61) (mod M61); the total stays below 2 * M61 + 12.
    const uint64_t acc = (mid >> 61) + (mid & M61) + (low >> 61) + (low & M61);

    if (acc >= M61_2x) return acc - M61_2x;
    if (acc >= M61) return acc - M61;
    return acc;
}

///Appends the character code ch to a polynomial hash: returns (hh * base + ch) mod M61.
///mul() already yields a value below M61, so one conditional subtraction reduces the sum.
inline uint64_t hh_add_char(uint64_t hh, uint8_t ch, uint64_t base) {
    uint64_t extended = mul(hh, base) + ch;
    if (extended >= M61) extended -= M61;
    return extended;
}

///Subtracts the contribution ch * base_pw from hh: returns (hh - ch * base_pw) mod M61.
///To drop the leading character of a hashed string of length L, pass base_pw = base^(L-1).
///Assumes hh < M61.
inline uint64_t hh_rm_char(uint64_t hh, uint8_t ch, uint64_t base_pw) {
    const uint64_t removed = mul(ch, base_pw);
    if (hh < removed) hh += M61;
    return hh - removed;
}

///Slides a hashed window one character to the right: removes the leading ch_bye (whose weight is base_pw)
///and appends ch.
///TODO: there are only a few (ch_bye, ch) combinations; could they be precomputed?
inline uint64_t hh_roll(uint64_t hh, uint8_t ch_bye, uint8_t ch, uint64_t base, uint64_t base_pw) {
    return hh_add_char(hh_rm_char(hh, ch_bye, base_pw), ch, base);
}

///Hash of the substring s[l .. r] (inclusive), with 0 <= l <= r < n.
///Assumes s_cuts[i] is the hash of the first i characters of s (s_cuts has n + 1 entries)
///and base_pws[i] = base^i; base_pws must hold index r - l + 1, which can be n.
inline uint64_t hh_cut(const std::vector<uint64_t> &s_cuts, const std::vector<uint64_t> &base_pws, int l, int r) {
    const uint64_t whole = s_cuts[r + 1];
    const uint64_t head = mul(s_cuts[l], base_pws[r - l + 1]);
    if (whole >= head) return whole - head;
    return whole + (M61 - head);
}

///Counts matches inside one group that shares a window hash.
///Precondition: prefs[i].hh_ps[0] == ts_tmp[j].hh_ps[0] for every pref_l <= i <= pref_r and ts_l <= j <= ts_r.
///For each word ts_tmp[j], adds to ts_count[ts_tmp[j].ind] the multiplicity (lev) of every window whose shade
///starts with the word's suffix (hash ts_tmp[j].hh_ps[1], length suff_len).
///Expects ts_tmp[ts_l .. ts_r] sorted by suff_len ascending, then by hh_ps[1] ascending, because:
///  - words with suff_len == 0 must come first;
///  - words with equal suff_len are handled in one pass after sorting the shades;
///  - identical consecutive words reuse the previous count.
///Side effects:
///  - overwrites prefs[j].hh_ps[1] with the hash of the first suff_len shade characters, but only where the shade is long enough;
///  - may reorder prefs[pref_l .. pref_r].
void solve_group(
    std::vector<PrefixInfo> &prefs, int pref_l, int pref_r,
    std::vector<TsInfo> &ts_tmp, int ts_l, int ts_r,
    const std::vector<uint64_t> &s_cuts, const std::vector<uint64_t> &base_pws, std::vector<int> &ts_count
) {
    int i = ts_l;

    while (i <= ts_r && ts_tmp[i].suff_len == 0) { ///suff_len == 0: every window in prefs[pref_l .. pref_r] is a match.
        if (i == ts_l) {
            for (int j = pref_l; j <= pref_r; j++) ts_count[ts_tmp[i].ind] += prefs[j].lev;
        } else ts_count[ts_tmp[i].ind] = ts_count[ts_tmp[i-1].ind];
        i++;
    }

    bool sorted_shades;
    while (i <= ts_r) {
        sorted_shades = false;
        if (i == ts_l || ts_tmp[i].suff_len > ts_tmp[i-1].suff_len) {
            ///first word with this suff_len: recompute each shade hash so it covers only the first suff_len characters.
            ///shades shorter than suff_len keep a stale hh_ps[1]; the length checks below exclude them.
            for (int j = pref_l; j <= pref_r; j++) {
                if (prefs[j].sh_loc.se - prefs[j].sh_loc.fi + 1 >= ts_tmp[i].suff_len) {
                    prefs[j].hh_ps[1] = hh_cut(s_cuts, base_pws, prefs[j].sh_loc.fi, prefs[j].sh_loc.fi + ts_tmp[i].suff_len - 1);
                }
            }

            if (i < ts_r && ts_tmp[i].suff_len == ts_tmp[i+1].suff_len) {
                ///sorting the truncated shades only pays off when at least two suffixes share this length.
                std::sort(prefs.begin() + pref_l, prefs.begin() + pref_r + 1, [](const PrefixInfo &a, const PrefixInfo &b){ return a.hh_ps[1] < b.hh_ps[1]; });
                sorted_shades = true;
            }
        }

        if (!sorted_shades) {
            ///only one word with this suff_len in the group: scan every window linearly, checking that its shade is long enough.
            for (int j = pref_l; j <= pref_r; j++) {
                if (prefs[j].sh_loc.se - prefs[j].sh_loc.fi + 1 >= ts_tmp[i].suff_len && prefs[j].hh_ps[1] == ts_tmp[i].hh_ps[1]) {
                    ts_count[ts_tmp[i].ind] += prefs[j].lev;
                }
            }
            i++;
        } else {
            int j = pref_l, z = i;
            while (z <= ts_r && ts_tmp[z].suff_len == ts_tmp[i].suff_len) {
                if (z > i && ts_tmp[z].hh_ps[1] == ts_tmp[z-1].hh_ps[1]) ts_count[ts_tmp[z].ind] = ts_count[ts_tmp[z-1].ind]; ///duplicate word: reuse the previous count.
                else {
                    ///both the truncated shade hashes and the suffixes of this run are sorted,
                    ///so a two-pointer scan resumes from where the previous word stopped.
                    while (j <= pref_r && prefs[j].hh_ps[1] < ts_tmp[z].hh_ps[1]) j++;
                    while (j <= pref_r && prefs[j].hh_ps[1] == ts_tmp[z].hh_ps[1]) {
                        if (prefs[j].sh_loc.se - prefs[j].sh_loc.fi + 1 >= ts_tmp[z].suff_len) ts_count[ts_tmp[z].ind] += prefs[j].lev;
                        j++;
                    }
                }
                z++;
            }

            i = z;
        }
    }
}

int main() {
    std::ifstream fin("ahocorasick.in");
    std::ofstream fout("ahocorasick.out");

    std::string s; fin >> s;
    int n = s.size();

    std::mt19937_64 mt(time(NULL));
    uint64_t base = std::uniform_int_distribution<uint64_t>(27, M61 - 1)(mt);
    
    std::vector<uint64_t> base_pws(n, 1);
    for (int i = 1; i < n; i++) base_pws[i] = mul(base_pws[i-1], base);

    std::vector<uint64_t> s_cuts(n+1); ///hashurile peste prefixe.
    for (int i = 0; i < n; i++) s_cuts[i+1] = hh_add_char(s_cuts[i], s[i]-'a'+1, base);

    int q; fin >> q;
    std::vector<std::string> ts(q);
    std::vector<int> ts_count(q), ts_order(q);
    for (int i = 0; i < q; i++) fin >> ts[i];

    std::iota(ts_order.begin(), ts_order.end(), 0);
    std::sort(ts_order.begin(), ts_order.end(), [&ts](int a, int b) { return ts[a].size() < ts[b].size(); });
    int ts_ind_l = 0;

    std::vector<PrefixInfo> prefs(n);
    std::vector<TsInfo> ts_tmp(q);

    for (int z = 0; (1 << z) <= n; z++) {
        ///generez toate subsecv de lungime 2**z din s, shade-urile lor, tin minte locatiile shade-urilor.
        prefs[0].hh_ps = {0, 0};
        prefs[0].sh_loc = std::make_pair((1<<z), std::min(n-1, (1<<(z+1))-2));

        prefs[0].hh_ps[0] = s_cuts[1<<z];
        for (int i = prefs[0].sh_loc.fi; i <= prefs[0].sh_loc.se; i++) prefs[0].hh_ps[1] = hh_add_char(prefs[0].hh_ps[1], s[i]-'a'+1, base);

        for (int i = (1 << z), j = 1; i < n; i++, j++) {
            prefs[j].hh_ps[0] = hh_roll(prefs[j-1].hh_ps[0], s[j-1]-'a'+1, s[i]-'a'+1, base, base_pws[(1<<z)-1]);
            prefs[j].sh_loc = std::make_pair(i+1, std::min(n-1, i+1 + (1<<z)-2));

            if (z > 0) {
                if (i+1 + (1<<z) - 2 < n) {
                    prefs[j].hh_ps[1] = hh_roll(prefs[j-1].hh_ps[1], s[i]-'a'+1, s[i+1+(1<<z)-2]-'a'+1, base, base_pws[(1<<z)-2]);
                } else {
                    prefs[j].hh_ps[1] = hh_rm_char(prefs[j-1].hh_ps[1], s[i]-'a'+1, base_pws[(1<<z)-2]);
                }
            }
        }

        int cnt_prefs = n + 1 - (1<<z);
        std::sort(prefs.begin(), prefs.begin() + cnt_prefs, [](const PrefixInfo &a, const PrefixInfo &b) { return a.hh_ps < b.hh_ps; });

        ///scapam de pref + shade identice.
        {
            int i = 0, y = 0;
            while (i < cnt_prefs) {
                int j = i;
                while (j < cnt_prefs && prefs[j].hh_ps == prefs[i].hh_ps && prefs[j].sh_loc.se - prefs[j].sh_loc.fi == prefs[i].sh_loc.se - prefs[i].sh_loc.fi) {
                    j++;
                }

                prefs[y].hh_ps = prefs[i].hh_ps;
                prefs[y].sh_loc = prefs[i].sh_loc;
                prefs[y++].lev = j - i;
                i = j;
            }

            cnt_prefs = y;
        }

        ///calculez raspunsul pentru elementele din ts[] care au MSB egal cu 2**z.
        while (ts_ind_l < q && (int)ts[ts_order[ts_ind_l]].size() < (1 << z)) ts_ind_l++;
        
        ///calculez hash-urile pentru prefix & sufix.
        int k = 0, ind = ts_order[ts_ind_l+k];
        while (ts_ind_l + k < q && (int)ts[ind].size() < (1 << (z+1))) {
            ts_tmp[k].hh_ps = {0, 0};
            for (int i = 0, j = 0; i < (int)ts[ind].size(); i++, j += (i == (1<<z))) ts_tmp[k].hh_ps[j] = hh_add_char(ts_tmp[k].hh_ps[j], ts[ind][i]-'a'+1, base);
            ts_tmp[k].suff_len = (int)ts[ind].size() - (1<<z);
            ts_tmp[k].ind = ind;

            k++;
            if (ts_ind_l + k < q) ind = ts_order[ts_ind_l+k];
        }

        std::sort(ts_tmp.begin(), ts_tmp.begin() + k, [](const TsInfo &a, const TsInfo &b) {
            if (a.hh_ps[0] != b.hh_ps[0]) return a.hh_ps[0] < b.hh_ps[0]; ///intai acelasi prefix ca sa putem face grupurile.
            if (a.suff_len != b.suff_len) return a.suff_len < b.suff_len; ///apoi aceeasi lungime. pentru aceeasi lungime putem calcula intr-o sg trecere prin partea cealalta a grupului. (aveam "<"..)
            return a.hh_ps[1] < b.hh_ps[1]; ///iar in final dupa sufix. daca si sufixul e identic, pot refolosi rezultatul de dinainte.
        });

        ///formez grupele pref - ts_tmp.
        {
            int i = 0, j, z = 0, y;

            while (z < k) {
                while (i < cnt_prefs && prefs[i].hh_ps[0] < ts_tmp[z].hh_ps[0]) i++;
                
                if (i < cnt_prefs && prefs[i].hh_ps[0] == ts_tmp[z].hh_ps[0]) {
                    y = z;
                    while (y < k && prefs[i].hh_ps[0] == ts_tmp[y].hh_ps[0]) y++;
                    j = i;
                    while (j < cnt_prefs && prefs[j].hh_ps[0] == ts_tmp[z].hh_ps[0]) j++;

                    solve_group(prefs, i, j-1, ts_tmp, z, y-1, s_cuts, base_pws, ts_count);
                    i = j; z = y;
                } else z++;
            }
        }

        ts_ind_l += k;
    }

    for (int cnt: ts_count) fout << cnt << '\n';            

    return 0;
}