// Source code (job #586284)
//
// User: Gavrila Vlad (GavrilaVlad) | Date: 30 April 2011 14:22:16
// Problem: Guvern | Score: 70
// Compiler: cpp | Status: done
// Round: Algoritmiada 2011, Final Round, Grades 10-12 | Size: 2.04 kb
#include <stdio.h>
#include <vector>
#include <set>

using namespace std;

#define maxn 200010
#define mlog 19

int n, i, j, k, a, b, el, rez, coop[maxn], f[maxn], niv[maxn], d[maxn], sv[maxn];
vector<int> v[maxn], prov[maxn];
set<pair<int, int> > g;
int st[mlog][maxn];

void df(int nod, int tata)
{
    if(f[nod]==1)
        return;
    f[nod]=1;

    niv[nod]=niv[tata]+1;
    st[0][nod]=tata;
    for(int i=1; st[i-1][st[i-1][nod]]>0; ++i)
        st[i][nod]=st[i-1][st[i-1][nod]];

    g.insert(make_pair(coop[nod], nod));

    for(int i=0; i<v[nod].size(); ++i)
        df(v[nod][i], nod);

    set<pair<int, int> > ::iterator it=g.upper_bound(make_pair(coop[nod], nod));
    if(it!=g.end())
        prov[it->second].push_back(nod);

    g.erase(make_pair(coop[nod], nod));

    sv[0]=0;
    int sol, left, right;
    long long sum=0;

    for(int i=0; i<prov[nod].size(); ++i)
    {
        left=1;
        right=sv[0];
        sol=0;
        while(left<=right)
        {
            int med=(left+right)/2;
            int nc=sv[med];

            for(int j=mlog-1; j>=0; --j)
            {
                if(st[j][nc]==0)
                    continue;
                if(niv[st[j][nc]]>=niv[prov[nod][i]])
                    nc=st[j][nc];
            }

            if(nc!=prov[nod][i])
            {
                sol=med;
                left=med+1;
            }
            else
                right=med-1;
        }

        sum=0;
        for(int j=sol+1; j<=sv[0]; ++j)
            sum+=d[sv[j]];

        if(sum<d[prov[nod][i]])
        {
            sv[0]=sol+1;
            sv[sv[0]]=prov[nod][i];
        }
    }

    for(int i=1; i<=sv[0]; ++i)
        d[nod]+=d[sv[i]];
    ++d[nod];

    rez=max(rez, d[nod]);
}

int main()
{
    freopen("guvern.in", "r", stdin);
    freopen("guvern.out", "w", stdout);

    scanf("%d", &n);
    for(int i=1; i<n; ++i)
    {
        scanf("%d%d", &a, &b);
        v[a].push_back(b);
        v[b].push_back(a);
    }
    for(int i=1; i<=n; ++i)
        scanf("%d", &coop[i]);

    df(1, 0);

    printf("%d\n", rez);
    return 0;
}