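/*
 * Source code (job #592226)
 *
 * User: George Macovei (iepurasu_paste)    Date: 27 May 2011 10:20:27
 * Problem: Xor Max    Score: 0
 * Compiler: cpp    Status: done
 * Round: Problem archive    Size: 3.06 kb
 *
 * Note: the program reads ct.in and writes ct.out, so it appears to target
 * the problem "ct" rather than "Xor Max".
 */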
#include<stdio.h>
#include<string.h>
#include<vector>
using namespace std;

// Shorthand macros used throughout this file. NOTE: `x` and `y` are
// object-like macros, so every later token spelled `x` / `y` is rewritten to
// `first` / `second`; never use them as ordinary identifiers below.
#define x first
#define y second
#define pb push_back

// Array dimension: upper bound on node / pair indices.
const int NMAX = 100005;

typedef std::pair<int,int> pi;

int t, n, k, l, nr;          // tests left, nodes, pairs, floor(log2 n), Euler-tour counter
char viz[NMAX];              // visited flags, reset before each traversal
int AIB[2*NMAX];             // Fenwick tree over Euler-tour positions 1..2n
int d[NMAX][21];             // d[v][i] = 2^i-th ancestor of v (0 above the root)
int niv[NMAX];               // depth of each node (root stays 0)
std::vector<int> v[NMAX];    // tree adjacency lists
std::vector<int> lc[NMAX];   // lc[u] = indices of the pairs whose LCA is u
pi ter[NMAX];                // the k query pairs, 1-based
pi ren[NMAX];                // ren[v] = (entry, exit) Euler-tour timestamps

// Fenwick-tree point update: adds `val` at position `poz` of AIB.
// Positions span 1..2*n, the range of the Euler-tour timestamps.
// `poz & -poz` is the lowest set bit of poz. poz must be >= 1; with poz == 0
// the step is 0 and the loop would never terminate.
void update(int poz,int val)
{
    while (poz <= 2*n)
    {
        AIB[poz] += val;
        poz += poz & -poz;
    }
}

// Fenwick-tree prefix query: returns the sum of every value added via
// update() at positions 1..poz (0 for poz == 0).
int query(int poz)
{
    int total = 0;
    while (poz != 0)
    {
        total += AIB[poz];
        poz -= poz & -poz;   // drop the lowest set bit
    }
    return total;
}

// First traversal, started from `nod`. The caller must already have set
// d[nod][0] (parent) and niv[nod] (depth); main leaves both 0 for the root.
// For each node it:
//   - marks the node visited,
//   - completes its binary-lifting row d[nod][1..l],
//   - stamps entry/exit times ren[nod].x / ren[nod].y from the counter nr.
// Unvisited neighbours become children: they get parent and depth, then recurse.
void dfs(int nod)
{
    viz[nod] = 1;

    // The 2^j-th ancestor is the 2^(j-1)-th ancestor of the 2^(j-1)-th
    // ancestor; ancestor rows are always filled before their descendants'.
    for (int j = 1; j <= l; ++j)
    {
        const int mid = d[nod][j - 1];
        d[nod][j] = d[mid][j - 1];
    }

    ren[nod].x = ++nr;                       // entry timestamp

    const int deg = static_cast<int>(v[nod].size());
    for (int idx = 0; idx < deg; ++idx)
    {
        const int child = v[nod][idx];
        if (viz[child])
            continue;                        // already visited (in a tree: the parent)
        d[child][0] = nod;
        niv[child] = niv[nod] + 1;
        dfs(child);
    }

    ren[nod].y = ++nr;                       // exit timestamp
}

// Returns the node reached by climbing `val` levels up from `nod`, using the
// binary-lifting table d. The bits of `val` are consumed from 2^put downwards.
// Precondition: val < 2^(put+1). Otherwise `put` would become negative and
// `1 << put` would be undefined.
int find(int nod,int val,int put)
{
    while (val != 0)
    {
        const int step = 1 << put;
        if (step <= val)
        {
            nod = d[nod][put];
            val -= step;
        }
        --put;
    }
    return nod;
}

// Lowest common ancestor of nod1 and nod2 via binary lifting:
//  1. lift the deeper node until both are at the same depth;
//  2. if they coincide, that node is the answer;
//  3. otherwise lift both together, from the largest jump l down to 1, while
//     their ancestors differ. The answer is then the parent of nod1.
// Relies on niv / d having been filled by dfs(1) and l = floor(log2 n).
int lca(int nod1,int nod2)
{
    const int diff = niv[nod1] - niv[nod2];
    if (diff > 0)
        nod1 = find(nod1, diff, l);
    else if (diff < 0)
        nod2 = find(nod2, -diff, l);

    if (nod1 == nod2)
        return nod1;

    for (int lvl = l; lvl >= 0; --lvl)
    {
        const int up1 = d[nod1][lvl];
        const int up2 = d[nod2][lvl];
        if (up1 != up2)
        {
            nod1 = up1;
            nod2 = up2;
        }
    }
    return d[nod1][0];
}

// Second, post-order traversal. It greedily picks nodes so that every query
// pair's tree path contains a picked node, and returns the number of nodes
// picked inside the subtree of `nod`.
//
// Picked nodes are stored in the Fenwick tree as +1 at their entry time and
// -1 at their exit time. Hence query(entry(u)) counts the picked nodes among
// u and its ancestors. Subtracting query(entry(nod)) leaves the picked nodes
// on the path from u up to, but not including, nod.
//
// Once all children are done, each pair whose LCA is `nod` is checked. If
// neither half of some pair's path holds a picked node, `nod` itself is
// picked.
int dfs_last(int nod)
{
    int picked = 0;
    viz[nod] = 1;

    const int deg = static_cast<int>(v[nod].size());
    for (int idx = 0; idx < deg; ++idx)
    {
        const int child = v[nod][idx];
        if (!viz[child])
            picked += dfs_last(child);
    }

    // AIB is not modified inside the scan below, so this prefix is invariant.
    const int base = query(ren[nod].x);
    bool uncovered = false;
    const int pairs = static_cast<int>(lc[nod].size());
    for (int idx = 0; idx < pairs && !uncovered; ++idx)
    {
        const int pid = lc[nod][idx];
        const int onFirst = query(ren[ter[pid].x].x) - base;
        const int onSecond = query(ren[ter[pid].y].x) - base;
        uncovered = (onFirst == 0 && onSecond == 0);
    }

    if (uncovered)
    {
        ++picked;
        update(ren[nod].x, 1);
        update(ren[nod].y, -1);
    }
    return picked;
}

int main ()
{
    int i,a,b,val,sol;
    
    freopen("ct.in","r",stdin);
    freopen("ct.out","w",stdout);
    scanf("%d",&t);
    for(;t;t--)
    {
        memset(viz,0,sizeof(viz));
        memset(AIB,0,sizeof(AIB));
        memset(niv,0,sizeof(niv));
        memset(d,0,sizeof(d));
        memset(ren,0,sizeof(ren));
        memset(ter,0,sizeof(ter));

        scanf("%d%d",&n,&k);
        for(l=0;(1<<l)<=n;l++);l--;
        
        for(i=1;i<n;i++)
        {
            scanf("%d%d",&a,&b);
            v[a].pb(b);
            v[b].pb(a);
        }
        nr=0;
        dfs(1);
        for(i=1;i<=k;i++)
        {
            scanf("%d%d",&ter[i].x,&ter[i].y);
            val=lca(ter[i].x,ter[i].y);
            lc[val].pb(i);
        }
        memset(viz,0,sizeof(viz));
        sol=dfs_last(1);
        printf("%d\n",sol);
        for(i=0;i<=n;i++)
            v[i].clear();
        for(i=0;i<=n;i++)
            lc[i].clear();
    }
    
    return 0;
}