Cod sursă (job #1072113)

Utilizator: blasterz (Mircea Dima) | Data: 3 ianuarie 2014 23:17:19
Problema: PScPld | Scor: 0
Compilator: cpp | Status: done
Runda: Arhiva de probleme | Mărime: 2.56 kb
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
using namespace std;

std::ifstream fin("pscpld.in");
std::ofstream fout("pscpld.out");

using ll = long long;

// Capacity of the per-position tables; the input string is stored 1-based
// (read into a + 1), so index 0 and one slot past the end are also used.
constexpr int nmax = 1000005;
// Primary hash modulus; inv[] holds a modular inverse for every residue below it.
constexpr int mod = 10007;
// Secondary hash modulus, used by Hash2.
constexpr int mod2 = 200003;
// alternative modulus considered: 666013

int i, j, n;
int pw, cnt, save;
// The palindrome count can reach n(n+1)/2 (about 5e11 for a 10^6-character
// string of one repeated letter), which does not fit in an int.
ll Ans;

char a[nmax];

int h[nmax];   // forward prefix hashes: h[k] covers a[1..k]
int rh[nmax];  // backward suffix hashes: rh[k] covers a[k..n]

int Pw[nmax], Pw2[nmax];           // powers of 257 modulo mod / mod2
int inv[mod + 1], inv2[mod2 + 1];  // modular inverses modulo mod / mod2

// Computes a^b mod md by binary exponentiation: walk the exponent's bits from
// least significant upward, squaring the base each step and folding it into
// the result whenever the current bit is set. Intermediate products are kept in
// 64-bit integers so squaring an int-sized residue cannot overflow.
inline int lgput(int a, int b, int md) {
    long long result = 1;
    long long base = a;
    while (b != 0) {
        if (b & 1)
            result = (result * base) % md;
        base = (base * base) % md;
        b >>= 1;
    }
    return int(result % md);
}

// Returns (h[i] - h[j]) * inv[Pw[dv]] modulo `mod`, with the difference first
// brought into [0, mod). Pw[dv] is 257^dv, so multiplying by its inverse strips
// the positional factor from a window's hash and lets windows at different
// offsets (or from the forward and backward tables) be compared directly.
inline int Hash(int j, int i, int dv, int h[]) {
    int diff = (h[i] - h[j]) % mod;
    if (diff < 0)
        diff += mod;
    const int scale = inv[Pw[dv]];
    return int((1LL * diff * scale) % mod);
}

inline int Hash2(int j, int i, int dv, int h[]) {
    int s = ((h[i] - h[j]) % mod2 + mod2) % mod2;
    int div = Pw2[dv];
    div = inv2[div];
    int hv = int((1LL * s * div) % mod2);
    return hv;
}

inline int Found (int L) {
    if (j - L <= 0) return 0;
    if (j + L > n) return 0;
    if (Hash(j - L - 1, j + L, j - L - 1, h) == Hash(j + L + 1, j - L, n - (j + L),  rh)) {
        return 1;
    }

    return 0;
}

inline int Found2(int L) {
    if (j - L <= 0 || j + L > n) return 0;
    if (Hash(j - L, j + L, j - L, h) == Hash(j + L + 1, j - L + 1, n - (j + L), rh))
    {
        return 1;
    }
    return 0;
}

// Counts the palindromic substrings of s exactly, in O(|s|), with Manacher's
// algorithm.
//   odd[c]  = number of odd-length palindromes centred at index c (radius + 1)
//   even[c] = number of even-length palindromes whose right-middle index is c
// The answer is the sum of both arrays. It is accumulated in a long long because
// it can reach |s|(|s|+1)/2.
static long long countPalindromicSubstrings(const std::string& s) {
    const int len = static_cast<int>(s.size());
    std::vector<int> odd(len), even(len);
    long long total = 0;

    // [lo, hi] is the rightmost palindrome found so far. Mirrored values inside
    // it seed each new centre, so the total extension work stays linear.
    for (int c = 0, lo = 0, hi = -1; c < len; ++c) {
        int k = (c > hi) ? 1 : std::min(odd[lo + hi - c], hi - c + 1);
        while (c - k >= 0 && c + k < len && s[c - k] == s[c + k])
            ++k;
        odd[c] = k;
        total += k;
        if (c + k - 1 > hi) {
            lo = c - k + 1;
            hi = c + k - 1;
        }
    }

    for (int c = 0, lo = 0, hi = -1; c < len; ++c) {
        int k = (c > hi) ? 0 : std::min(even[lo + hi - c + 1], hi - c + 1);
        while (c - k - 1 >= 0 && c + k < len && s[c - k - 1] == s[c + k])
            ++k;
        even[c] = k;
        total += k;
        if (c + k - 1 > hi) {
            lo = c - k;
            hi = c + k - 1;
        }
    }
    return total;
}

// Reads one whitespace-delimited string from pscpld.in and writes the number of
// its palindromic substrings to pscpld.out. The count is exact: it uses no
// hashing, so hash collisions cannot report false palindromes, and every centre
// (including the one between the first two characters) is considered.
int main() {
    std::string s;
    fin >> s;
    fout << countPalindromicSubstrings(s) << '\n';
    return 0;
}