/*
SD 2023 - Trie
*/
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#define ALPHABET_SIZE 26
#define ALPHABET "abcdefghijklmnopqrstuvwxyz"
/* One node of the trie. */
typedef struct trie_node_t {
    /* Value associated with key (set if end_of_word = 1) */
    void *value;

    /* 1 if current node marks the end of a word, 0 otherwise */
    int end_of_word;

    /* Array of child pointers, one slot per alphabet letter (NULL = absent) */
    struct trie_node_t **children;

    /* Number of non-NULL slots in children */
    int n_children;
} trie_node_t;
/* Trie handle: root node plus the bookkeeping shared by all operations. */
typedef struct trie_t {
    /* Root node; it corresponds to the empty key */
    trie_node_t *root;

    /* Number of keys */
    int size;

    /* Generic Data Structure: size in bytes of every stored value */
    int data_size;

    /* Trie-Specific, alphabet properties */
    int alphabet_size;   /* number of child slots allocated per node */
    char *alphabet;      /* the letters themselves, as passed to trie_create */

    /* Callback to free value associated with key, should be called when freeing */
    void (*free_value_cb)(void *);

    /* Optional - number of nodes, useful to test correctness */
    int nNodes;
} trie_t;
/*
 * Allocates an empty node (no value, not an end of word, no children) with
 * trie->alphabet_size child slots, all NULL, and counts it in trie->nNodes.
 *
 * Returns the new node, or NULL when an allocation fails. On failure nothing
 * is leaked and trie->nNodes is left unchanged.
 */
trie_node_t* trie_create_node(trie_t * trie) {
    trie_node_t *node = malloc(sizeof *node);
    if (node == NULL) {
        return NULL;
    }

    /* calloc zero-fills the slots, so every child starts out as NULL. */
    node->children = calloc(trie->alphabet_size, sizeof *node->children);
    if (node->children == NULL) {
        free(node);
        return NULL;
    }

    node->value = NULL;
    node->end_of_word = 0;
    node->n_children = 0;

    /* Only count nodes that really exist. */
    trie->nNodes++;
    return node;
}
/*
 * Creates an empty trie.
 *
 * data_size      size in bytes of each stored value (must not be negative)
 * alphabet_size  number of child slots every node gets
 * alphabet       alphabet string; the pointer is stored, not copied
 * free_value_cb  callback stored in the trie for releasing values
 *
 * The root node receives a zero-filled value buffer of data_size bytes that
 * starts with the placeholder string "_" (truncated if data_size is smaller
 * than the placeholder).
 *
 * Returns the new trie, or NULL when data_size is negative or an allocation
 * fails; nothing is leaked in that case.
 */
trie_t* trie_create(int data_size, int alphabet_size, char* alphabet, void (*free_value_cb)(void*)) {
    static const char placeholder[] = "_";

    if (data_size < 0) {
        return NULL;
    }

    trie_t *trie = malloc(sizeof *trie);
    if (trie == NULL) {
        return NULL;
    }

    trie->size = 0;
    trie->data_size = data_size;
    trie->alphabet_size = alphabet_size;
    trie->alphabet = alphabet;
    trie->free_value_cb = free_value_cb;
    trie->nNodes = 0;

    trie->root = trie_create_node(trie);
    if (trie->root == NULL) {
        free(trie);
        return NULL;
    }

    /* Ask for at least one byte so a zero data_size never yields NULL. */
    size_t value_size = (size_t) data_size;
    trie->root->value = calloc(1, value_size > 0 ? value_size : 1);
    if (trie->root->value == NULL) {
        free(trie->root->children);
        free(trie->root);
        free(trie);
        return NULL;
    }

    /*
     * Copy no more than the placeholder actually holds: reading data_size
     * bytes from the 2-byte literal would run past its end.
     */
    size_t copy_len = value_size < sizeof placeholder ? value_size : sizeof placeholder;
    memcpy(trie->root->value, placeholder, copy_len);

    return trie;
}
/*
 * Stores a copy of `value` (trie->data_size bytes) under `key`.
 *
 * - Letters are mapped to child slots as key[i] - 'a'. A key containing a
 *   character whose slot falls outside [0, alphabet_size) is rejected and
 *   the trie is left untouched.
 * - Inserting a key that already has a value buffer overwrites that buffer
 *   in place instead of allocating a new one.
 * - trie->size grows only when the key was not already marked as a word.
 *
 * The call is a no-op when trie, key or value is NULL. If an allocation
 * fails the key is not stored; path nodes created before the failure stay
 * in the trie and are released by trie_free.
 */
void trie_insert(trie_t* trie, char* key, void* value) {
    if (trie == NULL || trie->root == NULL || key == NULL || value == NULL) {
        return;
    }
    if (trie->data_size < 0) {
        return;
    }

    size_t key_len = strlen(key);

    /* Validate the whole key first so a bad key never creates partial paths. */
    for (size_t i = 0; i < key_len; i++) {
        int index = key[i] - 'a';
        if (index < 0 || index >= trie->alphabet_size) {
            return;
        }
    }

    trie_node_t *current = trie->root;
    for (size_t i = 0; i < key_len; i++) {
        int index = key[i] - 'a';
        if (current->children[index] == NULL) {
            trie_node_t *child = trie_create_node(trie);
            if (child == NULL) {
                return;
            }
            current->children[index] = child;
            current->n_children++;
        }
        current = current->children[index];
    }

    size_t value_size = (size_t) trie->data_size;
    if (current->value == NULL) {
        current->value = malloc(value_size > 0 ? value_size : 1);
        if (current->value == NULL) {
            return;
        }
    }
    memcpy(current->value, value, value_size);

    if (!current->end_of_word) {
        current->end_of_word = 1;
        trie->size++;
    }
}
/*
 * Looks up `key` and returns a pointer to the value stored for it, or NULL
 * when the key is not present. The returned pointer refers to memory owned
 * by the trie.
 *
 * Special case: for the empty key the root's value buffer is overwritten
 * with the int -1 (limited to data_size bytes) and returned; NULL is
 * returned if the root has no value buffer.
 *
 * Letters are mapped to child slots as key[i] - 'a'; a key containing a
 * character outside [0, alphabet_size) cannot be stored and yields NULL.
 * NULL is also returned when trie or key is NULL.
 */
void* trie_search(trie_t* trie, char* key) {
    if (trie == NULL || trie->root == NULL || key == NULL) {
        return NULL;
    }

    trie_node_t *node = trie->root;

    if (key[0] == '\0') {
        if (node->value != NULL) {
            const int sentinel = -1;
            size_t room = trie->data_size > 0 ? (size_t) trie->data_size : 0;
            /* Never write more bytes than the value buffer was allocated with. */
            memcpy(node->value, &sentinel, room < sizeof sentinel ? room : sizeof sentinel);
        }
        return node->value;
    }

    for (const char *p = key; *p != '\0'; p++) {
        int index = *p - 'a';
        if (index < 0 || index >= trie->alphabet_size) {
            return NULL;
        }
        node = node->children[index];
        if (node == NULL) {
            return NULL;
        }
    }

    /* Reaching a node is not enough: it must terminate a stored key. */
    return node->end_of_word ? node->value : NULL;
}
/*
 * Recursive worker for trie_remove.
 *
 * `node` is the node reached after consuming key[0..index). When the end of
 * the key is reached and the node marks a word, the mark is cleared,
 * trie->size is decremented and the node's value is released. The root
 * keeps its value buffer because trie_search dereferences it for the empty
 * key.
 *
 * On the way back up, a child left with no children and no word mark is
 * freed completely (value, child-slot array and the node itself), and
 * n_children / trie->nNodes are updated.
 *
 * Nothing happens if the key is not in the trie or contains a character
 * whose slot (key[i] - 'a') falls outside [0, alphabet_size).
 */
void trie_remove_helper(trie_t* trie, trie_node_t* node, char* key, int index) {
    /* End of key reached; same test as index == strlen(key), without rescanning. */
    if (key[index] == '\0') {
        if (node->end_of_word) {
            node->end_of_word = 0;
            trie->size--;
            if (node != trie->root) {
                free(node->value);
                node->value = NULL;
            }
        }
        return;
    }

    int slot = key[index] - 'a';
    if (slot < 0 || slot >= trie->alphabet_size) {
        return;
    }

    trie_node_t *child = node->children[slot];
    if (child == NULL) {
        return;
    }

    trie_remove_helper(trie, child, key, index + 1);

    /* Prune the child once nothing below it is left and it holds no key. */
    if (child->n_children == 0 && child->end_of_word == 0) {
        free(child->value);
        free(child->children);
        free(child);
        node->children[slot] = NULL;
        node->n_children--;
        trie->nNodes--;
    }
}
/*
 * Removes `key` from the trie by running the recursive helper from the
 * root, starting at the first character of the key.
 */
void trie_remove(trie_t* trie, char* key) {
    trie_remove_helper(trie, trie->root, key, 0);
}
/*
 * Releases `node` and everything below it: each child subtree, the
 * child-slot array, the value buffer and the node itself.
 *
 * alphabet_size is the number of child slots each node was allocated with;
 * walking exactly that many slots avoids reading past the array (or
 * skipping children) when the trie does not use a 26-letter alphabet.
 * A NULL node is ignored.
 */
static void __trie_free(trie_node_t* node, int alphabet_size) {
    if (node == NULL) {
        return;
    }

    if (node->children != NULL) {
        for (int i = 0; i < alphabet_size; i++) {
            __trie_free(node->children[i], alphabet_size);
        }
        free(node->children);
    }

    free(node->value);
    free(node);
}

/*
 * Destroys the whole trie and sets *pTrie to NULL. Does nothing when pTrie
 * or *pTrie is NULL.
 *
 * NOTE(review): values are released with free(), as everywhere else in this
 * file; trie->free_value_cb is not invoked. Confirm whether the callback is
 * meant to be used instead.
 */
void trie_free(trie_t** pTrie) {
    if (pTrie == NULL || *pTrie == NULL) {
        return;
    }

    trie_t *trie = *pTrie;
    __trie_free(trie->root, trie->alphabet_size);
    free(trie);
    *pTrie = NULL;
}
/*
 * Needed for Lambda tests, ignore :)
 *
 * If the second-to-last character of str is a backslash, the string is cut
 * at that backslash (so a trailing "\n" typed as two characters is dropped).
 * Strings shorter than two characters have no such position and are left
 * untouched.
 */
void cleanup_example_string(char* str) {
    size_t len = strlen(str);
    if (len >= 2 && str[len - 2] == '\\') {
        str[len - 2] = '\0';
    }
}
/*
 * Test driver. Reads an operation count from the first line of stdin, then
 * one operation per line:
 *
 *   i <key> <int>   insert key with an int value
 *   r <key>         remove key, printing nNodes before and after
 *   s <key>         search key ("_" stands for the empty key) and print the
 *                   value or "not found"
 *
 * Lines that cannot be parsed are skipped, and the loop stops early when
 * the input ends, so no uninitialized data is ever used.
 */
int main(void) {
    int n = 0, value = 0;
    char alphabet[] = ALPHABET;
    char buf[256], key[256], op;

    trie_t* trie = trie_create(sizeof(int), ALPHABET_SIZE, alphabet, free);
    if (trie == NULL) {
        return EXIT_FAILURE;
    }

    if (fgets(buf, sizeof buf, stdin) == NULL || sscanf(buf, "%d\n", &n) != 1) {
        n = 0;
    }

    for (int i = 0; i < n; ++i) {
        if (fgets(buf, sizeof buf, stdin) == NULL) {
            break;
        }
        if (sscanf(buf, "%c", &op) != 1) {
            continue;
        }

        if (op == 'i') {
            if (sscanf(buf, "%c %255s %d\n", &op, key, &value) != 3) {
                continue;
            }
            trie_insert(trie, key, &value);
        } else if (op == 'r') {
            if (sscanf(buf, "%c %255s\n", &op, key) != 2) {
                continue;
            }
            cleanup_example_string(key);
            printf("nNodes before removing %s: %d\n", key, trie->nNodes);
            trie_remove(trie, key);
            printf("nNodes after removing %s: %d\n", key, trie->nNodes);
        } else if (op == 's') {
            if (sscanf(buf, "%c %255s\n", &op, key) != 2) {
                continue;
            }
            cleanup_example_string(key);
            /* "_" on the command line denotes the empty key. */
            if (key[0] == '_') {
                key[0] = '\0';
            }
            int* found = trie_search(trie, key);
            printf("%s: ", key[0] == '\0' ? "_" : key);
            if (found) {
                printf("%d\n", *found);
            } else {
                printf("not found\n");
            }
        }
    }

    trie_free(&trie);
    return 0;
}