Cod sursa(job #1282647)

Utilizator serban_ioan97 (Ciofu Serban) Data 4 decembrie 2014 16:45:57
Problema Salvare Scor 100
Compilator cpp Status done
Runda Arhiva de probleme Marime 2.27 kb
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
// Capacity of every per-node array; nodes are indexed 1..n, so n must stay below it.
const int nmax = 1010;
#define pb push_back

using namespace std;

// Input sizes (node count, number of rescue points) and the radius d under test.
int n, k, d;
// sum: centres placed by the current okay() run; sol: best radius found so far;
// l: current recursion depth, maintained by dfs().
int sum, sol, l;

// Tree adjacency lists and their cached sizes (g[i] == a[i].size()).
vector <int> a[nmax];
int g[nmax];
// used[x]: DFS depth of x in the current run (0 = not visited yet).
int used[nmax];
// Per-node greedy state filled by dfs(): c = node value, c2 = second-best child
// value, v = subtree contains a centre, p = child that produced c in the first loop.
int c[nmax], c2[nmax], v[nmax], p[nmax];
// s: centres chosen by the latest okay() run; rez: copy kept for the best radius.
int rez[nmax], s[nmax];

/*
 * Greedy post-order pass for the radius stored in the global d, starting at
 * `nod` (okay() calls dfs(1)). Every node at which a centre is forced is
 * appended to s[1..sum].
 *
 * Per-node state. This interpretation is inferred from the arithmetic below;
 * the original code does not spell it out.
 *   c[nod] in 0..d      : nod is covered; a centre in its subtree lies
 *                         c[nod] edges away.
 *   c[nod] in d+1..2d   : the subtree still has uncovered nodes; the deepest
 *                         one lies c[nod]-(d+1) edges below nod
 *                         (d+1 == nod itself is uncovered).
 *   c2[nod]             : second-largest child value c[child]+1 seen in the
 *                         first loop; p[nod] is the child giving the largest.
 *   v[nod]              : 1 if nod's subtree contains a centre.
 *   used[nod]           : 1-based DFS depth (0 = not visited in this run).
 * When c[nod] reaches 2d+1, the deepest uncovered node is exactly d edges
 * below nod. nod is then made a centre (c=0, v=1).
 */
void dfs(int nod)
{
    int i, found=0;
    // l tracks recursion depth; it is restored by the l-- at the end.
    l++;
    used[nod]=l;

    // First loop: recurse into every unvisited neighbour (= child). Keep the
    // largest (c, p) and second-largest (c2) child value c[child]+1, and
    // propagate the "subtree has a centre" flag.
    for(i=0; i<g[nod]; i++)
    if(!used[a[nod][i]])
    {
        found=1;
        dfs(a[nod][i]);

        if (c[a[nod][i]]+1>c[nod])
        {
            c2[nod]=c[nod];
            c[nod]=c[a[nod][i]]+1;
            p[nod]=a[nod][i];
        }
        else if (c[a[nod][i]]+1>c2[nod]) c2[nod]=c[a[nod][i]]+1;

        v[nod]|=v[a[nod][i]];
    }

    // Second loop: for each child (neighbour deeper than nod) whose subtree
    // holds a centre, check whether that centre, c[child]+1 edges away, also
    // reaches nod's deepest uncovered node. The test c[nod]+c[child] <= 2d is
    // (c[nod]-(d+1)) + (c[child]+1) <= d. If it passes, nod becomes covered
    // at the smaller distance c[child]+1.
    // For the child that produced the maximum (p[nod]), the remaining
    // requirement comes from the second-largest value c2[nod].
    // NOTE(review): after the first loop c[nod] == c[p[nod]]+1, and this loop
    // only lowers c[nod]. The first condition of the p[nod] branch below
    // therefore appears never to hold; confirm before relying on it.
    for(i=0; i<g[nod]; i++)
    if((used[a[nod][i]]>used[nod]) && (v[a[nod][i]]))
    {
        if ((a[nod][i]!=p[nod]))
        {
            if ((c[nod]>c[a[nod][i]]+1) && (c[nod]+c[a[nod][i]]<=2*d)) c[nod]=c[a[nod][i]]+1;
        }
        else if ((c[nod]>c[a[nod][i]]+1) && (c2[nod]+c[a[nod][i]]<=2*d)) c[nod]=c[a[nod][i]]+1;
    }

    // Leaf (no unvisited neighbour): the node itself is uncovered, at depth 0.
    if (!found) c[nod]=d+1;

    // The deepest uncovered node is exactly d edges below: a centre must go here.
    if (c[nod]==2*d+1)
    {
        c[nod]=0;
        v[nod]=1;
        s[++sum]=nod;
    }

    l--;
}

/*
 * Feasibility check for the radius currently stored in the global d.
 *
 * Runs dfs(1), which appends every centre it is forced to place to
 * s[1..sum]. If node 1 is still uncovered afterwards (c[1] > d), it becomes a
 * centre as well. The free slots s[sum+1..k] are then padded with the
 * lowest-numbered nodes that are not centres yet. So whenever sum <= k and
 * k <= n, s[1..k] lists k distinct nodes.
 *
 * Returns the number of centres the greedy needs for radius d (the caller
 * compares it with k).
 */
int okay()
{
    sum=0;

    // Reset per-run DFS state. p[] is not cleared: dfs() assigns p[x] for
    // every node x that has a child before it ever reads p[x].
    memset(used, 0, sizeof(used));
    memset(v, 0, sizeof(v));
    memset(c, 0, sizeof(c));
    memset(c2, 0, sizeof(c2));

    dfs(1);

    // Nothing lies above the root, so if it is still uncovered it must host a centre.
    if(c[1]>d)
    {
        s[++sum]=1;
        c[1]=0;
    }

    // Pad the answer with distinct non-centre nodes. c[j]==0 marks a centre or
    // an already-taken slot; entries past n are also 0 and are never picked.
    // The scan is bounded by the array size so that k > n cannot make it read
    // or write past the end of c[].
    const int cap = (int)(sizeof(c) / sizeof(c[0]));
    int j = 1;
    for (int i = sum + 1; i <= k; i++)
    {
        while (j < cap && c[j] == 0) j++;
        if (j == cap) break;   // fewer than k distinct nodes exist
        s[i] = j;
        c[j] = 0;              // mark as taken so it is not picked twice
    }

    return sum;
}

int main()
{
    freopen("salvare.in", "rt", stdin);
    freopen("salvare.out","wt", stdout);

    scanf("%d %d ", &n, &k);

    int front=0, back=n, i, x, y;

    for (i=1; i<n; i++)
    {
        scanf("%d %d ", &x, &y);
        a[x].pb(y);
        a[y].pb(x);
    }

    for (i=1; i<=n; i++) g[i]=a[i].size();

    while(front<=back)
    {
        d=(front+back)/2;

        if(okay()<=k)
        {
            sol=d;
            memcpy(rez, s, sizeof(s));
            back=d-1;
        }
        else front=d+1;
    }

    sort(rez+1, rez+k+1);

    printf("%d\n", sol);

    for(i=1; i<=k; i++) printf("%d ", rez[i]);
    printf("\n");

    return 0;
}