/**
 * @file trie.c
 * @brief Byte-indexed trie mapping string keys to user-data pointers.
 * @author Joe Wingbermuehle
 * @date 2007-07-03
 */

#include "trie.h"
#include <stdlib.h>
#include <string.h>

/** One node of the trie.
 * A node is reached by following one key byte per level from the root;
 * its children are indexed by the unsigned value (0-255) of the next byte.
 */
struct TrieNode {
   struct TrieNode *children[256];  /**< Child for each next byte; NULL if absent. */
   void *data;                      /**< User-data stored at this key (NULL if none). */
};
typedef struct TrieNode TrieNode;

/* Internal helpers; all have internal linkage. */
static const TrieNode *FindNode(const TrieNode *root, const char *key);
static void *RemoveHelper(TrieNode *root, const char *key);
static int IsEmpty(const TrieNode *node);
static void DestroyHelper(TrieNode *root);

/** Create an empty trie.
 * The handle is a heap-allocated pointer to a zero-initialized root node;
 * release it with DestroyTrie.
 * @return The new trie, or NULL if memory could not be allocated.
 */
Trie *CreateTrie(void) {
   TrieNode **result = malloc(sizeof(*result));
   if(!result) {
      return NULL;
   }

   /* calloc zeroes the node: no children and no data. */
   *result = calloc(1, sizeof(**result));
   if(!*result) {
      free(result);
      return NULL;
   }

   return (Trie*)result;
}

/** Destroy a trie, freeing every node and the handle itself.
 * User-data pointers stored in the trie are not freed; the caller owns them.
 * @param trie The trie to destroy; NULL is accepted and ignored.
 */
void DestroyTrie(Trie *trie) {
   TrieNode **root;

   if(!trie) {
      return;
   }

   root = (TrieNode**)trie;
   DestroyHelper(*root);
   free(root);
}

/** Recursively free a node and all of its descendants.
 * Accepts NULL (nothing to do). Stored user-data is left untouched.
 */
static void DestroyHelper(TrieNode *node) {
   if(node != NULL) {
      TrieNode **child = node->children;
      TrieNode **const end = node->children + 256;

      /* Free the subtrees first, then the node that owns them. */
      while(child != end) {
         DestroyHelper(*child);
         ++child;
      }
      free(node);
   }
}

/** Insert an item into the trie.
 * Stores @p data under the key @p str and creates any missing nodes along
 * the way. A NULL or empty @p str addresses the root node.
 * @param trie The trie to modify.
 * @param str The key; each byte selects one level of the trie.
 * @param data The user-data to store (the trie never frees it).
 * @return The data previously stored under @p str (NULL if none). If a node
 *         cannot be allocated, returns NULL and leaves the trie unchanged.
 */
void *InsertTrie(Trie *trie, const char *str, void *data) {

   TrieNode **node = (TrieNode**)trie;
   TrieNode *tp = *node;
   TrieNode *firstParent = NULL;  /* Parent of the first node created here. */
   int firstIndex = 0;
   void *result;
   int index;

   /* Walk (and extend) the path for str; tp tracks the current node.
    * A NULL str behaves like the empty string: tp stays at the root. */
   if(str) {
      for(; *str; ++str) {
         index = (int)*(const unsigned char*)str;
         if(!tp->children[index]) {
            TrieNode *child = calloc(1, sizeof(*child));
            if(!child) {
               /* Roll back: every node created by this call hangs below
                * firstParent->children[firstIndex] as a single chain. */
               if(firstParent) {
                  DestroyHelper(firstParent->children[firstIndex]);
                  firstParent->children[firstIndex] = NULL;
               }
               return NULL;
            }
            if(!firstParent) {
               firstParent = tp;
               firstIndex = index;
            }
            tp->children[index] = child;
         }
         tp = tp->children[index];
      }
   }

   /* tp now points to the node for str; swap in the new data. */
   result = tp->data;
   tp->data = data;
   return result;

}

/** Remove the item stored under @p str (NULL or "" means the root).
 * Nodes left with neither data nor children are freed, except the root.
 * @return The data that was stored under @p str, or NULL if there was none.
 */
void *RemoveTrie(Trie *trie, const char *str) {
   TrieNode *const root = *(TrieNode**)trie;
   return RemoveHelper(root, str);
}

/** Recursively detach the data stored under @p str below @p node.
 * On the way back up, every child that has become empty (no data and no
 * children) is freed and unlinked; @p node itself is left for its caller.
 * @return The detached data, or NULL if @p str has no node.
 */
static void *RemoveHelper(TrieNode *node, const char *str) {

   TrieNode **slot;
   void *removed;

   /* End of the key: this node holds the data to detach. */
   if(str == NULL || *str == '\0') {
      removed = node->data;
      node->data = NULL;
      return removed;
   }

   /* The slot in node that holds the child for the next key byte. */
   slot = &node->children[*(const unsigned char*)str];
   if(*slot == NULL) {
      return NULL;
   }

   removed = RemoveHelper(*slot, str + 1);

   /* Prune the child whether or not the key was found below it. */
   if(IsEmpty(*slot)) {
      free(*slot);
      *slot = NULL;
   }

   return removed;
}

/** Look up the item stored under @p str (NULL or "" means the root).
 * @return The stored data, or NULL if no node exists for @p str.
 */
void *FindTrie(const Trie *trie, const char *str) {
   const TrieNode *const root = *(const TrieNode *const *)trie;
   const TrieNode *const match = FindNode(root, str);
   return match != NULL ? match->data : NULL;
}

/** Enumerate the possible next characters after a prefix.
 * Finds the node for @p str (the root if @p str is NULL or empty), then
 * visits each of its existing children in ascending byte order. For each
 * one, @p func (if non-NULL) receives the child's byte, the child's data,
 * and @p arg. The data is NULL for nodes that only serve as a prefix.
 * @return The number of children found; 0 if no node matches @p str.
 */
int EnumerateTrie(const Trie *trie, const char *str,
   TrieEnumerator func, void *arg) {

   const TrieNode **node = (const TrieNode**)trie;
   const TrieNode *tp;
   int x;
   int count = 0;  /* Initialized: it is returned even when nothing matches. */

   /* Determine the level to enumerate. */
   tp = FindNode(*node, str);
   if(!tp) {
      return 0;
   }

   /* Loop over each character at this level. */
   for(x = 0; x < 256; x++) {
      if(tp->children[x]) {
         if(func) {
            (func)((char)x, tp->children[x]->data, arg);
         }
         ++count;
      }
   }

   return count;

}

/** Helper method for looking up a TrieNode. */
const TrieNode *FindNode(const TrieNode *node, const char *str) {

   /* Loop over each character of the string. */
   if(str) {
      while(*str && node) {
         node = node->children[(int)*(unsigned char*)str];
         ++str;
      }
   }

   return node;

}

/** Determine if a trie has any children or data. */
int IsEmpty(const TrieNode *node) {

   int x;

   /* Check for data. */
   if(node->data) {
      return 0;
   }

   /* Check for children. */
   for(x = 0; x < 256; x++) {
      if(node->children[x]) {
         return 0;
      }
   }

   return 1;

}

