// Source code (job #3156907)
//
// User: AleXutzZu (David Alex Robert)   Date: 13 October 2023 17:27:08
// Problem: Aho-Corasick   Score: 95
// Compiler: cpp-64   Status: done
// Round: Arhiva educationala (educational archive)   Size: 5.19 kb
#include <array>
#include <cctype>
#include <cstdio>
#include <functional>
#include <iostream>
#include <queue>
#include <string>
#include <vector>

/**
 * Aho-Corasick automaton over the lowercase alphabet 'a'..'z'.
 *
 * Usage:
 * 1. Add every pattern with add_word().
 * 2. Call compute().
 * 3. Feed the text one character at a time through advance().
 * 4. Call get_answer(). Entries of the result at node ids returned by
 *    add_word() hold the number of (possibly overlapping) occurrences of that
 *    pattern in the text.
 */
struct Trie {
private:
    constexpr static const int K = 26;

    struct Vertex {
        // Child index per letter, -1 when absent. A fixed-size array avoids a
        // separate heap allocation for every vertex of the trie.
        std::array<int, K> next{};
        bool output = false;   // true if some pattern ends at this vertex
        int fail = -1;         // failure (longest proper suffix) link
        int exit = -1;         // dictionary link: -1 = not computed, -2 = none, >= 0 = vertex id
        int occurrences = 0;   // how many times advance() stopped on this vertex

        Vertex() {
            next.fill(-1);
        }
    };

    std::vector<Vertex> tree;
    std::vector<int> order;    // vertices in BFS order, root first (filled by compute())

    int &next(int node, char c) {
        return tree[node].next[c - 'a'];
    }

    bool &output(int node) {
        return tree[node].output;
    }

    int &fail(int node) {
        return tree[node].fail;
    }

    // BFS from the root that sets the failure link of every non-root vertex
    // and records the visiting order. Relies on compute() having given the
    // root a transition for every letter (missing ones point back to 0).
    void push_links() {
        std::queue<int> queue;

        for (char i = 'a'; i <= 'z'; ++i) {
            int s = next(0, i);
            if (s > 0) {
                queue.push(s);
                fail(s) = 0;
            }
        }

        while (!queue.empty()) {
            int top = queue.front();
            queue.pop();

            order.push_back(top);

            for (char i = 'a'; i <= 'z'; ++i) {
                int s = next(top, i);
                if (s == -1) continue;

                queue.push(s);
                // Follow failure links until some suffix can be extended by i.
                // This terminates because the root has a transition for every letter.
                int state = fail(top);
                while (next(state, i) == -1) state = fail(state);
                fail(s) = next(state, i);
            }
        }
    }

public:
    Trie() {
        tree.emplace_back();
    }

    /// Inserts `str` (lowercase letters only) and returns the id of its end vertex.
    /// Must be called before compute(): afterwards missing root transitions point
    /// back to the root, so new words would no longer be inserted correctly.
    int add_word(const std::string &str) {
        int node = 0;
        for (auto i: str) {
            if (next(node, i) == -1) {
                next(node, i) = (int) tree.size();
                tree.emplace_back();
            }
            node = next(node, i);
        }
        output(node) = true;
        return node;
    }

    /// Finalizes the automaton. Completes the root's transitions and builds the
    /// failure links and the BFS order. Calling it again rebuilds the same data.
    void compute() {
        order.clear();
        fail(0) = 0;
        order.push_back(0);

        for (char i = 'a'; i <= 'z'; ++i) {
            if (next(0, i) == -1) next(0, i) = 0;
        }

        push_links();
    }

    /// Computes the dictionary link of `node` and stores it in tree[node].exit
    /// (-2 if there is none). The dictionary link is the nearest vertex on the
    /// failure chain of `node`, excluding `node` itself and the root, at which a
    /// pattern ends. Requires compute().
    ///
    /// Iterative, and every vertex is resolved at most once. A full pass is
    /// therefore linear in the trie size, and long failure chains cannot
    /// overflow the call stack.
    void calc_exit(int node) {
        if (tree[node].exit != -1) return;

        // Climb the failure chain until a vertex with a known link (or the
        // root) is reached, remembering the unresolved vertices on the way.
        std::vector<int> pending;
        int cur = node;
        while (tree[cur].exit == -1) {
            pending.push_back(cur);
            if (cur == 0) break;
            cur = fail(cur);
        }

        // Resolve from the shallowest vertex down, so fail(v) is always settled
        // before v: exit(v) = fail(v) if a pattern ends there, else exit(fail(v)).
        for (auto it = pending.rbegin(); it != pending.rend(); ++it) {
            const int v = *it;
            const int f = fail(v);
            if (v == 0 || f == 0) {
                tree[v].exit = -2;
            } else if (output(f)) {
                tree[v].exit = f;
            } else {
                tree[v].exit = tree[f].exit;
            }
        }
    }

    /// Feeds text character `c` to the automaton, moving `state` (start at 0)
    /// and counting the visit. Requires compute(). Characters outside 'a'..'z'
    /// cannot occur in any pattern. They therefore reset `state` to the root
    /// instead of indexing outside the transition table.
    void advance(int &state, char c) {
        if (c < 'a' || c > 'z') {
            state = 0;
            return;
        }

        while (next(state, c) == -1) state = fail(state);
        state = next(state, c);

        tree[state].occurrences++;
    }

    /// Returns a vector indexed by vertex id. For every vertex returned by
    /// add_word(), the entry is the number of occurrences of that pattern in the
    /// text fed to advance(). The visit counters are not modified, so repeated
    /// calls return the same result.
    std::vector<int> get_answer() {
        std::vector<int> ans(tree.size(), 0);
        // Reverse BFS order visits deeper vertices first. Everything flowing into
        // a vertex is therefore accumulated before it is pushed to its dictionary link.
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int v = *it;
            ans[v] += tree[v].occurrences;
            calc_exit(v);

            if (tree[v].exit >= 0) ans[tree[v].exit] += ans[v];
        }
        return ans;
    }
};

/**
 * Minimal buffered reader over a C FILE, reading in 4096-byte chunks.
 * It owns the FILE it opens, closes it on destruction, and is not copyable.
 * Only bytes actually delivered by fread are consumed, and the extraction
 * operators stop cleanly at end of input.
 */
struct InputParser {
private:
    static constexpr int BUFFER_SIZE = 4096;

    FILE *stream;              // nullptr if the file could not be opened
    std::vector<char> buffer;
    int ptr{};                 // index of the next unread byte in buffer
    int len{};                 // number of valid bytes currently in buffer

    // Returns the next byte as an unsigned char value, or EOF at end of input.
    // Tracking `len` means a short final chunk never exposes leftover bytes
    // from the previous chunk. The unsigned value is also safe to pass to
    // the <cctype> classifiers.
    int get() {
        if (ptr == len) {
            ptr = 0;
            len = stream ? static_cast<int>(fread(buffer.data(), 1, buffer.size(), stream)) : 0;
            if (len == 0) return EOF;
        }
        return static_cast<unsigned char>(buffer[ptr++]);
    }

public:
    explicit InputParser(const std::string &file)
            : stream(fopen(file.c_str(), "r")), buffer(BUFFER_SIZE) {}

    ~InputParser() {
        if (stream) fclose(stream);
    }

    InputParser(const InputParser &) = delete;
    InputParser &operator=(const InputParser &) = delete;

    /// Reads one raw byte into `c`. Yields '\0' at end of input.
    inline InputParser &operator>>(char &c) {
        const int ch = get();
        c = ch == EOF ? '\0' : static_cast<char>(ch);
        return *this;
    }

    /// Skips non-letters, then appends the next run of letters to `str`.
    /// The character that ends the run is consumed. Nothing is appended at
    /// end of input.
    inline InputParser &operator>>(std::string &str) {
        int ch = get();
        while (ch != EOF && !std::isalpha(ch)) ch = get();
        while (ch != EOF && std::isalpha(ch)) {
            str += static_cast<char>(ch);
            ch = get();
        }
        return *this;
    }

    /// Skips non-digits, then parses an unsigned decimal number into `n`.
    /// The character that ends the number is consumed. `n` is left unchanged
    /// if the input ends before any digit.
    inline InputParser &operator>>(int &n) {
        int ch = get();
        while (ch != EOF && !std::isdigit(ch)) ch = get();
        if (ch == EOF) return *this;

        int value = 0;
        while (ch != EOF && std::isdigit(ch)) {
            value = value * 10 + (ch - '0');
            ch = get();
        }
        n = value;
        return *this;
    }
};

/**
 * Minimal buffered writer over a C FILE with a 50000-byte buffer.
 * It owns the FILE it opens, and on destruction flushes the buffer and
 * closes the file. Not copyable. If the file cannot be opened, output is
 * silently discarded instead of dereferencing a null FILE*.
 */
class OutputParser {
private:
    static constexpr int BUFFER_SIZE = 50000;

    FILE *stream;              // nullptr if the file could not be opened
    std::vector<char> buffer;
    int ptr;                   // number of pending bytes in buffer

    // Writes the pending bytes (if any) and empties the buffer.
    void flush() {
        if (stream && ptr > 0) fwrite(buffer.data(), 1, ptr, stream);
        ptr = 0;
    }

    void write(char ch) {
        if (ptr == BUFFER_SIZE) flush();
        buffer[ptr++] = ch;
    }

public:
    explicit OutputParser(const std::string &file)
            : stream(fopen(file.c_str(), "w")), buffer(BUFFER_SIZE), ptr(0) {}

    ~OutputParser() {
        flush();
        if (stream) fclose(stream);
    }

    OutputParser(const OutputParser &) = delete;
    OutputParser &operator=(const OutputParser &) = delete;

    /// Writes `vu32` in decimal, with a leading '-' for negative values.
    /// The magnitude is taken in unsigned arithmetic, so INT_MIN is printed
    /// correctly.
    inline OutputParser &operator<<(int vu32) {
        unsigned int magnitude = static_cast<unsigned int>(vu32);
        if (vu32 < 0) {
            write('-');
            magnitude = 0u - magnitude;
        }

        char digits[20];  // enough for any unsigned int up to 64 bits
        int count = 0;
        do {
            digits[count++] = static_cast<char>('0' + magnitude % 10);
            magnitude /= 10;
        } while (magnitude != 0);

        while (count > 0) write(digits[--count]);
        return *this;
    }

    inline OutputParser &operator<<(char ch) {
        write(ch);
        return *this;
    }
};

/// Reads the text and then n patterns from "ahocorasick.in". Writes to
/// "ahocorasick.out", one line per pattern in input order, the number of
/// occurrences of that pattern in the text.
int main() {
    InputParser input("ahocorasick.in");
    OutputParser output("ahocorasick.out");

    std::string text;
    int n = 0;  // initialized so a truncated input cannot leave it indeterminate
    input >> text >> n;

    // Insert every pattern, remembering the trie vertex where each one ends
    // (duplicate patterns share a vertex and hence an answer).
    Trie trie;
    std::vector<int> word_pos(n);
    for (int &pos : word_pos) {
        std::string word;
        input >> word;
        pos = trie.add_word(word);
    }

    trie.compute();
    int state = 0;
    for (char c : text) {
        trie.advance(state, c);
    }

    const std::vector<int> ans = trie.get_answer();
    for (int pos : word_pos) {
        output << ans[pos] << '\n';
    }
    return 0;
}