Cod sursa(job #1538784)

Utilizator algebristul (Filip Berila) Data 29 noiembrie 2015 19:36:07
Problema Cowfood Scor 28
Compilator cpp Status done
Runda Arhiva de probleme Marime 3.08 kb
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>

#define FILEIN "cowfood.in"
#define FILEOUT "cowfood.out"

using namespace std;

// Input parameters, read in main: k values per tuple, total budget S,
// and the number n of experiments.
int k;
int S;
int n;

// a[i][j]: component j (1-based, j <= k) of experiment i (0-based, i < n).
int a[20][31];

// One[i]: per-axis threshold used by main's final correction (1-based, i <= k).
int One[32];

// Set by main when a non-empty subset of experiments has an all-zero
// component-wise maximum.
bool Zero = false;

// sums[i]: counting table filled by pre() for 0 <= i <= S.
long long sums[10005];

// Modulus applied to every count.
const int MOD = 3210121;

// Returns true when every entry of values[1..k] is zero (index 0 is ignored).
bool all_zeroes(int values[]) {
    int j = 1;
    // Skip over leading zeros; if we run past k, nothing was non-zero.
    while (j <= k && values[j] == 0)
        ++j;
    return j > k;
}

// Returns the 1-based index of the only non-zero entry of values[1..k].
// When values[1..k] has no non-zero entry, or more than one, returns k + 1.
// Note that this sentinel is non-zero, so callers must range-check the result.
int one_not_zero(int values[]) {
    const int none = k + 1;
    int found = none;
    for (int j = 1; j <= k; ++j) {
        if (values[j] == 0)
            continue;
        if (found != none)
            return none;  // a second non-zero entry: not a single-axis vector
        found = j;
    }
    return found;
}

// Returns x^pow modulo MOD, computed by binary exponentiation.
//
// The base is first reduced into [0, MOD). Because MOD < 2^22, every product
// below stays under 2^44 and can never overflow a long long. Without this
// reduction, t * t overflowed (undefined behaviour) for |x| above ~3e9, and a
// negative x could produce a negative result.
//
// Expects pow >= 0. A negative exponent returns 1 instead of spinning forever
// (arithmetic right shift keeps -1 at -1).
long long explog(long long x, int pow) {
    long long base = x % MOD;
    if (base < 0)
        base += MOD;

    long long result = 1;
    while (pow > 0) {
        if (pow & 1)
            result = result * base % MOD;
        base = base * base % MOD;
        pow >>= 1;
    }
    return result;
}

// Fills sums[i], for 0 <= i <= S, with the number of k-tuples of non-negative
// integers whose total is at most i, i.e. C(i + k, k), reduced modulo MOD.
//
// Construction: the all-ones row counts 0-tuples (just the empty tuple) with
// total <= i. One prefix-sum pass turns the count for j components into the
// count for j + 1 components, because
//   N_{j+1}(<= i) = sum_{u=0..i} N_j(<= u).
// After k passes the row holds C(i + k, k).
//
// This replaces the former factorial/modular-inverse formula. That formula
// started its product at (k - 1) rather than (k - 1)!, which gave wrong values
// for k >= 4 and all zeros for k == 1. It also required MOD to be prime. The
// recurrence costs O(k * S) and needs no inverses.
void pre() {
    for (int i = 0; i <= S; i++)
        sums[i] = 1;

    for (int pass = 0; pass < k; pass++) {
        for (int i = 1; i <= S; i++) {
            sums[i] = (sums[i] + sums[i - 1]) % MOD;
        }
    }
}

// Entry point. Reads k, S, n and the n experiments (k values each) from
// FILEIN, then counts the k-tuples of non-negative integers with total <= S
// that are not component-wise >= any experiment. From that count it removes:
//   - the all-zero tuple, and
//   - the tuples with exactly one non-zero component.
// The result is printed modulo MOD to FILEOUT.
int main() {
    freopen(FILEIN, "r", stdin);
    freopen(FILEOUT, "w", stdout);

    scanf("%d %d %d", &k, &S, &n);
    for (int i = 0; i < n; i++) {
        for (int j = 1; j <= k; j++) {
            scanf("%d", &a[i][j]);
        }
    }

    pre();

    // For each axis i, One[i] - 1 is the number of single-axis tuples
    // (0, .., v, .., 0) with 1 <= v <= S that inclusion-exclusion did NOT
    // exclude. The initial value S + 1 means none were excluded.
    for (int i = 1; i <= k; i++)
        One[i] = S + 1;

    long long ans = 0;
    // Inclusion-exclusion over every subset m of experiments. A tuple is >= all
    // experiments in m exactly when it is >= their component-wise maximum.
    for (int m = 0; m < (1 << n); m++) {
        int sign = 1;
        int values[31];
        memset(values, 0, sizeof(values));
        for (int i = 0; i < n; i++) {
            if (m & (1 << i)) {
                sign = -sign;
                for (int j = 1; j <= k; j++) {
                    values[j] = std::max(values[j], a[i][j]);
                }
            }
        }

        // An all-zero maximum means every tuple is dominated, so every
        // single-axis tuple is already excluded. One[i] = 1 (not 0) makes the
        // final `One[i] - 1` correction subtract nothing for that axis.
        if (m && all_zeroes(values)) {
            Zero = true;
            for (int i = 1; i <= k; i++)
                One[i] = 1;
        }

        // one_not_zero() reports "no single non-zero axis" as k + 1, which is
        // truthy, so the result must be range-checked. Otherwise values[k + 1]
        // would be read past the end of the array when k == 30.
        const int axis = one_not_zero(values);
        if (axis >= 1 && axis <= k) {
            One[axis] = std::min(One[axis], values[axis]);
        }

        long long sum = 0;
        for (int j = 1; j <= k; j++) {
            sum += values[j];
        }

        if (sum > S)
            continue;  // no tuple with total <= S dominates this subset

        // Subtracting `values` maps the dominating tuples with total <= S onto
        // all tuples with total <= S - sum, which is what sums[S - sum] counts.
        ans += 1LL * sign * sums[S - sum];
        ans %= MOD;
    }

    // Drop the all-zero tuple, unless an all-zero experiment already excluded it.
    if (!Zero)
        ans--;
    // Drop the single-axis tuples that survived inclusion-exclusion.
    for (int i = 1; i <= k; i++) {
        ans -= One[i] - 1;
    }
    ans %= MOD;
    if (ans < 0)
        ans += MOD;

    printf("%lld\n", ans);

    return 0;
}