Cod sursa(job #2413638)

Utilizator stoianmihailStoian Mihail stoianmihail Data 23 aprilie 2019 16:38:02
Problema Aho-Corasick Scor 25
Compilator cpp-64 Status done
Runda Arhiva educationala Marime 7.17 kb
#include <fstream>
#include <cmath>
#include <cassert>
#include <iostream>
#include <cstring>
#include <cstdio>

using namespace std;

// Capacity (in characters) of the input buffer allocated in main().
#define MAX_LEN 1000000
// NOTE(review): not referenced in the visible source; presumably the maximum
// pattern length from the problem statement -- confirm before relying on it.
#define MAX_PATTERN 10000

// Alphabet size ('a'..'z'). Not referenced in the visible source.
const int SIGMA = 26;
// Number of distinct byte values; alphaTable has one slot per byte (plus one).
const int ALL = 1 << 8;
// Byte -> letter code lookup. Zero-initialised (namespace scope), so every
// byte that init() does not touch encodes to 0.
int alphaTable[ALL + 1];

// Fills alphaTable so that 'a'..'z' map to 1..26. Must run before encode().
void init() {
  for (char c = 'a'; c <= 'z'; c++)
    alphaTable[static_cast<unsigned char>(c)] = c - 'a' + 1;
}

// Returns the letter code of c: 1..26 for 'a'..'z' (after init()), 0 otherwise.
// The index goes through unsigned char because plain char may be signed, in
// which case bytes >= 0x80 would produce a negative, out-of-bounds index.
inline int encode(char c) {
  return alphaTable[static_cast<unsigned char>(c)];
}

// pi, derived from acos(0) == pi / 2.
const double PI = 2.0 * std::acos(0.0);

// Minimal complex number: a is the real part, b the imaginary part.
class complex {
public:
  double a, b;

  // Zero by default.
  complex() : a(0.0), b(0.0) {}
  complex(double na, double nb) : a(na), b(nb) {}

  complex operator+(const complex &c) const {
    return complex(a + c.a, b + c.b);
  }

  // Compound operators return *this so they can be chained like the built-ins.
  complex& operator+=(const complex &c) {
    a += c.a;
    b += c.b;
    return *this;
  }

  complex operator-(const complex &c) const {
    return complex(a - c.a, b - c.b);
  }

  // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
  complex operator*(const complex &c) const {
    return complex(a * c.a - b * c.b, a * c.b + b * c.a);
  }

  complex& operator*=(const complex &c) {
    // Go through operator* so both parts are computed from the old a and b.
    *this = *this * c;
    return *this;
  }

  // Euclidean norm sqrt(a^2 + b^2).
  double magnitude() const { return std::sqrt(a * a + b * b); }

  // Writes "(a b)" with three decimals, followed by a newline, to stdout.
  void print() const {
    std::printf("(%.3f %.3f)\n", a, b);
  }
};

// Precomputed bit-reversal and floor(log2) tables for every value in
// [0, MAX_SIZE]. A single global instance, `spiegel`, is built at start-up.
class mirrorBits {
  #define MAX_SIZE 1048576
private:
  unsigned size;
  // rev[i]: the bits of i mirrored within i's own bit-length (lg[i] + 1 bits).
  // rev[0] stays 0 from value-initialisation.
  unsigned *rev;
  // lg[i]: floor(log2(i)) for i >= 1; lg[0] stays 0.
  char* lg;

  // Fills rev and lg in one increasing pass.
  void compute() {
    int currLg = 0, mask = 0;
    rev[1] = 1;
    lg[1] = 0;
    for (unsigned i = 2; i <= size; i++) {
      // i is a power of two: one more bit is needed from here on.
      if (!(i & (i - 1))) {
        ++currLg;
        mask = (mask << 1) | 1;
      }
      // Take the bits below the leading bit, mirror them via the already
      // computed entry, shift past the zeros that sit between them and the
      // leading bit, then append the mirrored leading bit as the lowest bit.
      rev[i] = (rev[mask & i] << (currLg - lg[mask & i])) | 1;
      lg[i] = currLg;
    }
  }
public:
  mirrorBits() {
    this -> size = MAX_SIZE;
    rev = new unsigned[size + 1]();
    lg = new char[size + 1]();
    compute();
  }

  // Returns x with its bits mirrored inside a field of `space` bits.
  // Expects 0 <= x < 2^space and x <= MAX_SIZE.
  // TODO: store everything in rev if possible.
  inline int getRev(int space, int x) {
    // Zero mirrors to zero for any width. Handling it here also avoids a
    // shift by -1 when space == 0 (a transform of a single element).
    if (x == 0)
      return 0;
    return rev[x] << (space - lg[x] - 1);
  }

  // Returns floor(log2(x)) for 1 <= x <= MAX_SIZE, and 0 for x == 0.
  inline int getLog(int x) {
    return lg[x];
  }

  // Returns x itself when x is 0 or a power of two, otherwise the smallest
  // power of two greater than x.
  inline int getPower(int x) {
    return (!(x & (x - 1))) ? x : (1 << (lg[x] + 1));
  }
} spiegel;

// Set of "seen" marks over the indices [0, 2^log) that can be cleared in O(1).
//
// A mark is a stored bit equal to `color`. Once every index has been marked,
// wechseln() flips `color`, which turns all marks into "unmarked" at once
// without touching the storage. Calling wechseln() while some indices are
// still unmarked makes exactly those indices read as marked afterwards.
class bitset {
private:
  bool color;      // bit value that currently means "marked"
  int log;         // the set covers 2^log indices
  unsigned* bits;  // packed storage, 32 indices per word (unsigned keeps shifts well defined)
public:
  // Creates a set over [0, 2^log) with nothing marked.
  bitset(int log)
    : color(true),
      log(log),
      bits(new unsigned[((1 << log) >> 5) + 1]()) {}

  ~bitset() { delete[] bits; }

  // The storage is uniquely owned; copying would lead to a double delete.
  bitset(const bitset&) = delete;
  bitset& operator=(const bitset&) = delete;

  // True when x is currently marked.
  bool get(int x) {
    const unsigned bit = (bits[x >> 5] >> (x & 31)) & 1u;
    return bit == (color ? 1u : 0u);
  }

  // Marks x by storing the current color in its bit.
  void set(int x) {
    const unsigned mask = 1u << (x & 31);
    if (color)
      bits[x >> 5] |= mask;
    else
      bits[x >> 5] &= ~mask;
  }

  // Switches the color, i.e. inverts the meaning of every stored bit.
  void wechseln() {
    color = !color;
  }
};

class FFT {
public:
	complex* roots;
	int n, s;

	void initRoots(int log, int finalSize) {
	  s = log;
	  n = finalSize;
		roots = new complex[n + 1]();
		roots[0] = complex(1, 0);
		complex mult = complex(cos(2 * PI / n), sin(2 * PI / n));
		for (int i = 1; i <= n; i++)
			roots[i] = roots[i - 1] * mult;
	}

	void transform(complex* data, bool inverse = false) {
		int i, j, k;
		for (i = 1; i <= s; i++) {
			int m = (1 << i), md2 = m / 2;
			int start = 0, increment = (1 << (s-i));
			if (inverse) {
				start = n;
				increment *= -1;
			}
			complex t, u;
			for (k = 0; k < n; k += m) {
				int index = start;
				for (j = k; j < md2+k; j++) {
					t = roots[index] * data[j+md2];
					index += increment;
					data[j+md2] = data[j] - t;
					data[j] = data[j] + t;
				}
			}
		}
		if (inverse)
			for (i = 0; i < n; i++) {
				data[i].a /= n;
        data[i].b /= n;
			}
	}
} matching;

// Prefix sums of the squared letter codes of the text; allocated and filled
// in main(), read by getSum().
unsigned* partialSums;

// Returns x * x, computed in 32-bit unsigned arithmetic.
unsigned square(uint32_t x) {
  const uint32_t squared = x * x;
  return squared;
}

// Reorders fft[0..finalSize) into bit-reversed index order using the
// precomputed table in `spiegel`; finalSize is expected to be 1 << log.
//
// `checked` records which positions have already been swapped, so each pair
// is exchanged only once. Every position ends up marked, so flipping the
// color at the end leaves `checked` empty again for the next call.
//
// Always returns true. The function is declared bool but used to fall off
// the end without returning, which is undefined behaviour.
bool bitReverse(bitset& checked, complex *fft, unsigned finalSize, unsigned log) {
  unsigned remained = finalSize;
  for (unsigned index = 0; index < finalSize && remained; index++) {
    if (checked.get(index))
      continue;

    unsigned moveTo = spiegel.getRev(log, index);

    // Mark both ends as seen. index and moveTo may be the same position.
    checked.set(index);
    checked.set(moveTo);
    remained -= 1 + (index != moveTo);

    swap(fft[index], fft[moveTo]);
  }
  checked.wechseln();
  return true;
}

// Sum of the squared letter codes over the text window that starts at `pos`
// and spans `needlelen` characters, read from the partialSums prefix table.
int getSum(int pos, uint32_t needlelen) {
  unsigned total = partialSums[pos + needlelen - 1];
  if (pos != 0)
    total -= partialSums[pos - 1];
  return total;
}

// Counts the text positions at which the pattern occurs.
//
// fftString  : the already transformed text.
// fftPattern : the reversed pattern codes, zero padded to finalSize;
//              overwritten with the convolution result.
// A window matches when twice its dot product with the pattern equals the
// pattern's sum of squares plus the window's sum of squares, i.e. when the
// squared distance between window and pattern is zero.
unsigned solve(bitset& checked, complex *fftString, complex *fftPattern, unsigned textLen, unsigned finalSize, unsigned length, unsigned constPattern, unsigned log) {
  // Forward transform of the pattern.
  bitReverse(checked, fftPattern, finalSize, log);
  matching.transform(fftPattern);

  // Pointwise product in the frequency domain.
  for (unsigned k = 0; k < finalSize; ++k)
    fftPattern[k] = fftPattern[k] * fftString[k];

  // Back to the time domain.
  bitReverse(checked, fftPattern, finalSize, log);
  matching.transform(fftPattern, true);

  // Slide the window over the text; the dot product for the window starting
  // at `start` sits at index start + length - 1.
  const unsigned windows = textLen - length + 1;
  unsigned noMatches = 0;
  for (unsigned start = 0; start < windows; ++start) {
    const int dot = (int)(fftPattern[start + length - 1].a + 0.5);
    const unsigned expected = constPattern + getSum(start, length);
    if ((dot << 1) == expected)
      ++noMatches;
  }
  return noMatches;
}

// Exchanges the values of two complex numbers.
inline void swap(complex& a, complex& b) {
  const complex saved = b;
  b = a;
  a = saved;
}

// Prepares the pattern for the FFT: writes the letter codes of `str` in
// reverse order into (*fft)[0..length), zero-fills the rest up to finalSize,
// and stores the sum of the squared codes in constPattern.
void initForFFT(complex** fft, char* str, unsigned length, unsigned finalSize, unsigned& constPattern) {
  complex* out = *fft;
  unsigned squares = 0;

  unsigned pos = 0;
  while (pos < length) {
    out[pos] = complex(encode(str[length - pos - 1]), 0);
    squares += square(encode(str[pos]));
    ++pos;
  }

  // Zero padding behind the pattern.
  while (pos < finalSize) {
    out[pos] = complex();
    ++pos;
  }

  constPattern = squares;
}

int main(void) {
  ifstream cin("ahocorasick.in");
  ofstream cout("ahocorasick.out");

  // Preprocessing
  init();

  // Read input.
  char *text = new char[MAX_LEN + 1];
  cin >> text;

  // Modify the size of input.
  unsigned size = strlen(text);
  unsigned log = spiegel.getLog(spiegel.getPower(size));
  unsigned finalSize = 1 << log;

  // Construct partialSums (with the square of each value) and convert the text into complex numbers.
  partialSums = new unsigned[size];
  complex* fftString = new complex[finalSize]();
  for (unsigned index = 0; index < size; index++) {
    fftString[index] = complex(encode(text[index]), 0);
    partialSums[index] = square(encode(text[index])) + ((!index) ? 0 : partialSums[index - 1]);
  }

  bitset checked(log);
  bitReverse(checked, fftString, finalSize, log);

  matching.initRoots(log, finalSize);
  matching.transform(fftString);

  // create space for pattern.
  complex* fftPattern = new complex[finalSize]();

  int Q;
  cin >> Q;
  while (Q--) {
    cin >> text;
    unsigned length = strlen(text);
    if (length > size) {
      cout << 0 << '\n';
      continue;
    }
    unsigned constPattern;
    initForFFT(&fftPattern, text, length, finalSize, constPattern);
    //for (unsigned i = 0; i < finalSize; i++) fftPattern[i].print();
    cout << solve(checked, fftString, fftPattern, size, finalSize, length, constPattern, log) << '\n';
  }
  return 0;
}