Cod sursa(job #3258184)

Utilizator andrei.arnautuAndi Arnautu andrei.arnautu Data 21 noiembrie 2024 15:22:17
Problema Aho-Corasick Scor 100
Compilator cpp-64 Status done
Runda Arhiva educationala Marime 5.23 kb
/**
 *  Worg
 */
#include <algorithm>
#include <fstream>
#include <iostream>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

/**
 *  One node of the trie.
 *
 *  Child nodes are owned through unique_ptr and keyed by the character on the edge
 *  leading to them; the fail link is a plain (non-owning) pointer.
 */
class TrieNode {
public:
    //  Outgoing edges: character -> owned child node. Default-constructed empty.
    std::unordered_map<char, std::unique_ptr<TrieNode>> children;
    //  Null on construction; this class never assigns it itself.
    TrieNode* fail_link = nullptr;
    //  Ids attached to this node. Default-constructed empty.
    std::vector<int> word_ids;
    //  Counter that starts at zero.
    int text_occurence_count = 0;

    //  Every member already carries its initial state, so nothing is left to do here.
    TrieNode() = default;
};


class Trie {
private:
    std::unique_ptr<TrieNode> root;
    int total_node_count;

public:
    Trie() {
        root = std::make_unique<TrieNode>();
        total_node_count = 1;  //  Just the root initially
    }

    TrieNode* get_root() const {
        return root.get();
    }

    int get_trie_size() const {
        return total_node_count;
    }

    void insert(const std::string& word, const int& word_id) {
        TrieNode* current_node = root.get();

        for (const char& c : word) {
            if (current_node->children.find(c) == current_node->children.end()) {
                current_node->children[c] = std::make_unique<TrieNode>();
                total_node_count += 1;
            }
            current_node = current_node->children[c].get();
        }

        current_node->word_ids.push_back(word_id);
    }
};


class AhoCorasick {
private:
    Trie trie;
    std::string text;
    int total_word_count;

    void use_automaton_on_text() {
        //  Go through the automaton states using the text as an input
        TrieNode* current_node = trie.get_root();
        TrieNode* root = trie.get_root();
        for (const auto& ch : text) {
            while (current_node != root && current_node->children.find(ch) == current_node->children.end()) {
                current_node = current_node->fail_link;
            }

            if (current_node->children.find(ch) != current_node->children.end()) {
                current_node = current_node->children[ch].get();
            }

            current_node->text_occurence_count += 1;
        }
    }

public:
    AhoCorasick(const std::string& _text) : text(_text) {
        trie = Trie();
        total_word_count = 0;
    }

    void insert_word_in_dictionary(const std::string& word, const int& word_id) {
        trie.insert(word, word_id);
        total_word_count += 1;
    }

    void compute_fail_links() {
        TrieNode* root = trie.get_root();
        root->fail_link = root;

        std::queue<TrieNode*> node_queue;
        for (const auto& p : root->children) {
            p.second->fail_link = root;
            node_queue.push(p.second.get());
        }

        while (!node_queue.empty()) {
            TrieNode* node = node_queue.front();
            node_queue.pop();

            for (const auto& p : node->children) {
                TrieNode* current_fail_link = node->fail_link;

                while (current_fail_link != root && current_fail_link->children.find(p.first) == current_fail_link->children.end()) {
                    current_fail_link = current_fail_link->fail_link;
                }

                if (current_fail_link->children.find(p.first) != current_fail_link->children.end()) {
                    current_fail_link = current_fail_link->children[p.first].get();
                } else {
                    current_fail_link = root;
                }

                p.second.get()->fail_link = current_fail_link;
                node_queue.push(p.second.get());
            }
        }
    }

    std::vector<int> find_word_appearances() {
        use_automaton_on_text();

        //  We want to obtain a node ordering from bottom to top (i.e. a node should always be placed after its descendants)
        std::queue<TrieNode*> node_queue;
        std::vector<TrieNode*> ordered_nodes;
        ordered_nodes.reserve(trie.get_trie_size());

        node_queue.push(trie.get_root());
        while (!node_queue.empty()) {
            TrieNode* current_node = node_queue.front();
            node_queue.pop();

            ordered_nodes.push_back(current_node);
            for (const auto& p : current_node->children) {
                node_queue.push(p.second.get());
            }
        }
        std::reverse(ordered_nodes.begin(), ordered_nodes.end());

        std::vector<int> word_appearances(total_word_count, 0);
        for (const auto& node : ordered_nodes) {
            for (const auto& word_id : node->word_ids) {
                word_appearances[word_id] += node->text_occurence_count;
            }

            node->fail_link->text_occurence_count += node->text_occurence_count;
        }

        return word_appearances;
    }
};



/**
 *  Reads a text, a word count and that many dictionary words from "ahocorasick.in", then writes
 *  the occurrence count of each word, one per line and in input order, to "ahocorasick.out".
 *  Returns 1 if the input file cannot be opened, 0 otherwise.
 */
int main() {
    std::ifstream fin("ahocorasick.in");
    if (!fin) {
        std::cerr << "Cannot open ahocorasick.in\n";
        return 1;
    }

    std::string text;
    fin >> text;
    AhoCorasick aho_corasick(text);

    //  Zero-initialised on purpose: extraction from a stream that is already in a failed state
    //  leaves the target untouched, which would otherwise make the loop bound indeterminate.
    int n = 0;
    fin >> n;
    for (int i = 0; i < n; i++) {
        std::string dictionary_word;
        fin >> dictionary_word;
        aho_corasick.insert_word_in_dictionary(dictionary_word, i);
    }
    fin.close();

    std::ofstream fout("ahocorasick.out");

    aho_corasick.compute_fail_links();
    const std::vector<int> word_appearances = aho_corasick.find_word_appearances();
    for (const int count : word_appearances) {
        fout << count << '\n';
    }
    fout.close();

    return 0;
}