Cod sursa(job #50042)

Utilizator DastasIonescu Vlad Dastas Data 6 aprilie 2007 19:14:26
Problema Elimin Scor 50
Compilator cpp Status done
Runda Arhiva de probleme Marime 2.49 kb
#include <limits.h>
#include <stdio.h>
#include <string.h>

#include <algorithm>
#include <iostream>

// Input/output streams, opened during static initialization.
FILE *in = fopen("elimin.in","r");
FILE *out = fopen("elimin.out","w");

// Read by read(): m = first (row) index range of a, n = second (column)
// index range; r rows and c columns are to be removed.
int n, m, r, c;

// Matrix, filled 1-based as a[1..m][1..n] by read().
int a[600][600];

// Marks (value 1) the columns -- or rows, when enumerating rows -- removed
// in the selection currently being evaluated; reset after each evaluation.
int eliminate[600] = {0};

// Best remaining sum recorded so far by suma()/suma2().
int ss = 0;

/*
 * Partitions b[p..u] (inclusive) around the pivot b[p].
 * Elements <= pivot end up left of the returned index, elements >= pivot
 * end up right of it, and the pivot itself is stored at the returned index.
 * Uses the "moving hole" scheme: the slot vacated by the pivot is filled
 * alternately from the right end and the left end.
 */
int poz(int p, int u, int b[])
{
    const int pivot = b[p];
    int lo = p;
    int hi = u;

    while (lo < hi) {
        for (; lo < hi && b[hi] >= pivot; --hi)
            ;                       /* skip right-side elements already >= pivot */
        b[lo] = b[hi];

        for (; lo < hi && b[lo] <= pivot; ++lo)
            ;                       /* skip left-side elements already <= pivot */
        b[hi] = b[lo];
    }

    b[lo] = pivot;
    return lo;
}

/*
 * Sorts b[p..u] (inclusive bounds) in ascending order; a range with p >= u
 * is left untouched.
 *
 * std::sort is used instead of a first-element-pivot quicksort (poz): that
 * scheme takes quadratic time and linear recursion depth on already-sorted
 * or all-equal ranges, which occur whenever many row/column sums tie, and
 * this function runs once per enumerated combination.
 */
void qs(int p, int u, int b[])
{
    if (p < u)
        std::sort(b + p, b + u + 1);
}

/*
 * Reads m, n, r, c from `in`, then the m x n matrix row by row into
 * a[1..m][1..n] (1-based).
 */
void read()
{
    fscanf(in, "%d %d %d %d", &m, &n, &r, &c);

    for (int row = 1; row <= m; row++) {
        int *line = a[row];
        for (int col = 1; col <= n; col++)
            fscanf(in, "%d" , line + col);
    }
}

/*
 * Evaluates the current column selection, where eliminate[j] == 1 marks
 * column j as removed.
 * Each row is summed over the kept columns, the row sums are sorted
 * ascending, and the r smallest are dropped. The remaining total replaces
 * ss if it is larger. eliminate[] is cleared afterwards so the caller can
 * mark the next selection.
 */
void suma()
{
    int rowSum[600] = {0};

    for (int row = 1; row <= m; ++row) {
        int acc = 0;
        for (int col = 1; col <= n; ++col)
            if (!eliminate[col])
                acc += a[row][col];
        rowSum[row] = acc;
    }

    qs(1, m, rowSum);

    int kept = 0;
    for (int row = r + 1; row <= m; ++row)
        kept += rowSum[row];

    if (kept > ss)
        ss = kept;

    memset(eliminate, 0, sizeof eliminate);
}

/*
 * Evaluates the current row selection, where eliminate[i] == 1 marks row i
 * as removed.
 * Each column is summed over the kept rows, the column sums are sorted
 * ascending, and the c smallest are dropped. The remaining total replaces
 * ss if it is larger. eliminate[] is cleared afterwards so the caller can
 * mark the next selection.
 */
void suma2()
{
    int colSum[600] = {0};

    for (int row = 1; row <= m; ++row) {
        if (eliminate[row])
            continue;               /* removed rows contribute nothing */
        for (int col = 1; col <= n; ++col)
            colSum[col] += a[row][col];
    }

    qs(1, n, colSum);

    int kept = 0;
    for (int col = c + 1; col <= n; ++col)
        kept += colSum[col];

    if (kept > ss)
        ss = kept;

    memset(eliminate, 0, sizeof eliminate);
}

// Indices picked so far by back()/back2(), st[1..depth]; st[0] stays 0 so the first pick starts at 1.
int st[600] = {0};

/*
 * Enumerates every strictly increasing choice st[1..c] of column indices in
 * 1..n (each c-subset of columns exactly once). For each complete choice it
 * marks those columns in eliminate[] and calls suma() to evaluate it.
 * `col` is the position in st[] being filled.
 */
void back(int col)
{
    for (int cand = st[col - 1] + 1; cand <= n; ++cand) {
        st[col] = cand;

        if (col != c) {
            back(col + 1);          /* extend the partial choice */
            continue;
        }

        for (int k = 1; k <= c; ++k)
            eliminate[st[k]] = 1;
        suma();
    }
}

/*
 * Enumerates every strictly increasing choice st[1..r] of row indices in
 * 1..m (each r-subset of rows exactly once). For each complete choice it
 * marks those rows in eliminate[] and calls suma2() to evaluate it.
 * `col` is the position in st[] being filled.
 *
 * The recursion must stay in back2: recursing into back() would switch to
 * enumerating columns (bounded by n and c) and evaluate with suma(). Row
 * subsets with r > 1 would then never be examined.
 */
void back2(int col)
{
    for ( int i = st[col-1]+1; i <= m; ++i )
    {
        st[col] = i;
        if ( col == r )
        {
            for ( int t = 1; t <= r; ++t )
                eliminate[st[t]] = 1;

            suma2();
        }
        else
            back2(col+1);
    }
}

int main()
{
    read();
    memset(st, 0, sizeof(st));

    int p = 1, q = 1;

    for ( int i = m-r+1; i <= m; ++i )
        p *= i;
    for ( int i = n-c+1; i <= n; ++i )
        q *= i;

    if ( p < q )
        back2(1);
    else
        back(1);


    fprintf(out, "%d\n", ss);

	return 0;
}