/*
 * Cod sursa (job #2019735)
 *
 * Utilizator: antana (Antonia Boca)   Data: 8 septembrie 2017 13:43:38
 * Problema: Guvern   Scor: 100
 * Compilator: cpp   Status: done
 * Runda: Arhiva de probleme   Marime: 2.25 kb
 */
#include <bits/stdc++.h>

using namespace std;

// Capacity of every per-node array; node ids 1..n must stay below this.
constexpr int MAXN = 200000 + 1;
// "Infinity" sentinel (unused in the visible code).
constexpr int INF  = 0x3f3f3f3f;

FILE *fin  = nullptr;   // input stream, opened on "guvern.in" by main()
FILE *fout = nullptr;   // output stream, opened on "guvern.out" by main()

// Adjacency lists of the undirected input tree.
std::vector<int> G[MAXN];
// paths[p]: nodes u whose ancestor with the smallest val not below val[u] is p
// (p == 0 when no ancestor qualifies). Filled by dfs().
std::vector<int> paths[MAXN];

int n;              // number of nodes
int k;              // running counter; dfs() uses it as the entry-time clock
int val[MAXN];      // value of each node, read from the input
int d[MAXN];        // per-node DP answer, produced by compute()
int first[MAXN];    // entry time of each node in the DFS
int last[MAXN];     // largest entry time inside each node's subtree
int best[MAXN];     // scratch DP table used by compute()
int lg[MAXN];       // lg[i] = floor(log2(i)), filled by main()
bool seen[MAXN];    // visited marks, used by both traversals

// A candidate child for the interval DP in compute():
// [x, y] is the child's Euler-tour interval, dp its already-computed d[] value.
struct Sons {
    int x;   // start of the interval (first[] of the child)
    int y;   // end of the interval (last[] of the child)
    int dp;  // weight of the interval (d[] of the child)

    // Ordering by right endpoint, which the interval-scheduling DP sorts by.
    friend bool operator < (const Sons &lhs, const Sons &rhs) {
        return lhs.y < rhs.y;
    }
};

// Scratch buffer of candidates, used 1-indexed.
Sons v[MAXN];

// Orders node ids by their val[], so a set of ids is kept sorted by value.
// NOTE(review): two nodes with equal val compare equivalent, so the set could hold
// only one of them (and erase(u) would remove the other) — presumably the input
// guarantees distinct values; confirm.
struct compare {
    bool operator () (int lhs, int rhs) const {
        return val[ lhs ] < val[ rhs ];
    }
};

set < int, compare > srt;
set < int, compare >::iterator it;

// Depth-first walk of the tree starting at `node`, using an explicit stack so that a
// path-shaped tree with ~2e5 nodes cannot overflow the call stack.
// For every node u it reaches, it:
//   * marks seen[u] and stamps first[u] with the next value of the global clock k;
//   * appends u to paths[p], where p is the node in srt (u's ancestors) with the
//     smallest val not below val[u], or to paths[0] when there is none;
//   * once u's whole subtree is finished, sets last[u] = k, so u's subtree is exactly
//     the set of nodes whose first[] lies in [first[u], last[u]].
// While the walk is inside u, u is present in srt; it is removed on leaving.
void dfs( int node ) {
    // Each frame: a node and the index of the next neighbour to inspect.
    vector< pair< int, size_t > > frames;

    // Work the recursive version did on entering a node (before visiting children).
    auto enter = [&]( int u ) {
        seen[ u ] = 1;
        first[ u ] = ++k;

        auto above = srt.lower_bound( u );
        paths[ above != srt.end() ? *above : 0 ].push_back( u );

        srt.insert( u );
        frames.push_back( make_pair( u, size_t( 0 ) ) );
    };

    enter( node );
    while (!frames.empty()) {
        int u = frames.back().first;
        size_t &next = frames.back().second;

        if (next < G[ u ].size()) {
            int son = G[ u ][ next++ ];
            if (!seen[ son ])
                enter( son );   // may reallocate `frames`; `next` is not used after this
        } else {
            // All neighbours handled: equivalent of returning from the recursive call.
            srt.erase( u );
            last[ u ] = k;
            frames.pop_back();
        }
    }
}

void compute( int node ) {

    seen[ node ] = 1;
    for (int son: paths[ node ])
        if (!seen[ son ])
            compute( son );

    k = 0;
    for (int son: paths[ node ])
        v[ ++k ] = { first[ son ], last[ son ], d[ son ] };

    sort( v+1, v+k+1 );

    for (int i=1; i<=k; i++) {
        int ans = 0;
        for (int step = (1<<lg[ k ]); step >= 1; step >>= 1)
            if (ans + step <= k && v[ ans+step ].y < v[ i ].x)
                ans += step;
        best[ i ] = max( best[ i-1 ], best[ ans ] + v[ i ].dp );
    }

    d[ node ] = 1 + best[ k ];
}

int main()
{
    fin = fopen( "guvern.in", "r");
    fout= fopen( "guvern.out","w");

    int x, y;

    fscanf(fin, "%d", &n);
    for (int i = 1; i < n; i++) {
        fscanf(fin, "%d%d", &x, &y);
        G[ x ].push_back( y );
        G[ y ].push_back( x );
    }

    for (int i = 1; i <= n; i++)
        fscanf(fin, "%d", &val[ i ]);

    lg[ 2 ] = 1;
    for (int i = 3; i <= n; i++)
        lg[ i ] = lg[ i/2 ] + 1;

    dfs( 1 );
    memset( seen, 0, sizeof seen );
    compute( 0 );

    int ans = d[ 0 ] - 1;
    for (int i = 1; i <= n; i++)
        ans = max( ans, d[ i ] );

    fprintf(fout, "%d", ans);

    fclose( fin );
    fclose( fout );

    return 0;
}