/*
 * Cod sursa (job #1072104)
 *
 * Utilizator: blasterz (Mircea Dima)   Data: 3 ianuarie 2014 23:00:33
 * Problema: PScPld   Scor: 0
 * Compilator: cpp   Status: done
 * Runda: Arhiva de probleme   Marime: 2.38 kb
 */
#include <algorithm>
#include <iostream>
#include <fstream>
#include <cstring>
#include <cstdio>
using namespace std;

ifstream fin("pscpld.in");
ofstream fout("pscpld.out");

#define ll long long

// Array capacity: the unused slot 0 (strings are stored 1-based), the
// characters themselves, and the sentinels read just past the end
// (rh[n + 1], the string terminator).
const int nmax = 1000005;

// Modulus of the polynomial hashes. It had been left at the debugging value 13,
// which makes hash collisions between unrelated windows extremely common;
// 666013 is the value the author had commented out. It must stay prime, because
// Hash inverts powers of 257 via x^(mod - 2). Keep it small enough that
// mod * 257 fits in an int, since hash updates may be done in plain int
// arithmetic.
// NOTE(review): a single ~6.7e5 modulus over millions of comparisons can still
// produce occasional false matches; a 64-bit or double hash would be safer.
const int mod = 666013;

int i, j, n;        // j: current palindrome centre, read by Found/Found2
int pw, cnt, save;
// The palindrome count can reach n(n+1)/2 (e.g. one repeated letter), which
// overflows int for long inputs.
long long Ans;

char a[nmax];       // input string, stored 1-based in a[1..n]

int h[nmax];        // forward prefix hashes
int rh[nmax];       // reverse suffix hashes

/*
 * Square-and-multiply exponentiation: returns a^b reduced modulo `mod`.
 * Expects b >= 0. The 64-bit intermediates keep every product from
 * overflowing before it is reduced.
 */
inline int lgput(int a, int b) {
    long long result = 1;
    long long base = a;
    while (b != 0) {
        if (b & 1)
            result = result * base % mod;   // fold in the current bit
        base = base * base % mod;           // advance to the next power a^(2^k)
        b >>= 1;
    }
    return int(result % mod);
}

inline int Hash(int j, int i, int dv, int h[]) {
    int s = ((h[i] - h[j]) % mod + mod) % mod;
    //printf ("%d %d: (%d %d %d)\n", j - 1, i, h[i], h[j - 1], s);
    int div = lgput(257, dv) ;
    div = lgput(div, mod - 2);
    int hv = int((1LL * s * div) % mod);
    return hv;
}

inline int Found (int L) {
    if (j - L <= 0) return 0;
    if (j + L > n) return 0;
    //printf ("j: %d, L : %d\n", j, L);
    //printf ("(%d %d, %d %d--%d) %d %d\n", j - L - 1, j + L, j + L + 1, j - L, n - (j + L), hash(j - L - 1, j + L,j - L - 1,  h), hash(j + L + 1, j - L, n - (j + L),  rh));
    if (Hash(j - L - 1, j + L, j - L - 1, h) == Hash(j + L + 1, j - L, n - (j + L),  rh))
        return 1;
    

    return 0;
}

/*
 * Even-length palindrome test around the global centre j.
 * Returns 1 when the 2L-character window a[j-L+1 .. j+L] lies inside a[1..n]
 * and its forward and reversed hashes agree; otherwise returns 0. The window
 * is centred between a[j] and a[j+1].
 * A hash collision can make this a false positive.
 */
inline int Found2(int L) {
    // The window starts at j-L+1, so j-L == 0 is still valid (the window then
    // begins at a[1]). The previous test (j - L <= 0) dropped every even
    // palindrome that starts at the first character.
    if (j - L < 0 || j + L > n) return 0;
    const int forward = Hash(j - L, j + L, j - L, h);
    const int reversed = Hash(j + L + 1, j - L + 1, n - (j + L), rh);
    return forward == reversed ? 1 : 0;
}

int main() {
    fin >> (a + 1);
    n = strlen(a + 1);
    pw = 1;
    for (i = 1; i <= n; ++i) {
        h[i] = (h[i - 1] + pw * a[i]) % mod;
        pw = (pw * 257) % mod;
    }
    pw = 1;
    for (i = n; i >= 0; --i) {
        rh[i] = (rh[i + 1] + pw * a[i]) % mod;
        pw = (pw * 257) % mod;
    }
    /*for (i = 1; i <= n; ++i)
        printf ("%d ", h[i]);
    printf ("\n");
    for (i = 1; i <= n; ++i)
        printf ("%d ", rh[i]);
    printf ("\n");
    */
    for (cnt = 1; cnt <= n; cnt <<= 1);
    save = cnt;
    for (j = 2; j < n; ++j) {
        cnt = save;
        for (i = 0; cnt; cnt >>= 1) {
            if (i + cnt <= n && Found(i + cnt))
                i += cnt;
        }
        Ans += i;
        //printf ("%d: %d\n", j, i);
        
        cnt = save;
        for (i = 0; cnt; cnt >>= 1)
            if (i + cnt <= n && Found2(i + cnt))
                i += cnt;
        Ans += i;
    
    }
    fout << Ans + n << '\n';
    //printf ("%d\n", Ans + n);
    return 0;
}