Cod sursa(job #2727417)

Utilizator EckchartZgarcea Robert-Andrei Eckchart Data 21 martie 2021 20:45:17
Problema Suma divizorilor Scor 50
Compilator cpp-64 Status done
Runda Arhiva de probleme Marime 2.95 kb
#include "bits/stdc++.h"
#include <cassert>
using namespace std;
using ll = long long;
using ull = unsigned long long;
using ld = long double;
using pi = pair<int, int>;
using pl = pair<ll, ll>;
using pd = pair<double, double>;
using pld = pair<ld, ld>;
// Modulus in which every result is reported.
constexpr int M = 9901;
// Upper bound of the prime sieve: 7071 = floor(sqrt(5e7)), so trial division
// by primes up to this bound fully factors any value up to 5e7.
constexpr int MAX_SQRT_NR = 7071;


int main()
{
    bitset<MAX_SQRT_NR + 1> is_prime = move(bitset<MAX_SQRT_NR + 1>{}.set());
    vector<int> primes{2};
    auto eratosthenes = [&]() -> void
    {
        is_prime[0] = is_prime[1] = false;
        for (int j = 4; j <= MAX_SQRT_NR; j += 2)
        {
            is_prime[j] = false;
        }

        int i;
        for (i = 3; i * i <= MAX_SQRT_NR; i += 2)
        {
            if (is_prime[i])
            {
                primes.emplace_back(i);
                for (int j = i * i; j <= MAX_SQRT_NR; j += i)
                {
                    is_prime[j] = false;
                }
            }
        }

        for (i += !(i & 1); i <= MAX_SQRT_NR; i += 2)
        {
            if (is_prime[i])
            {
                primes.emplace_back(i);
            }
        }
    };
    
    auto mod = [](const ll a) -> ll
    {
        return (a % M + M) % M;
    };
    
    vector<int> mod_inverse(M, 1);
    auto precalc_mod_inverse = [&]() -> void
    {
        for (int i = 2; i < M; ++i)
        {
            mod_inverse[i] = mod(-(M / i) * mod_inverse[M % i]);
        }
    };
    
    eratosthenes();
    precalc_mod_inverse();
    
    ifstream cin("sumdiv.in");
    ofstream cout("sumdiv.out");

    int A, B;
    cin >> A >> B;

    if (A == 0 && B == 0)
    {
        cout << "1";
        return 0;
    }
    if (A == 0)
    {
        cout << "0";
        return 0;
    }
    if (B == 0)
    {
        cout << "1";
        return 0;
    }

    auto fast_exp = [&](ll a, int b) -> ll
    {
        ll res{1};
        while (b)
        {
            if (b & 1)
            {
                res = mod(res * a);
            }
            a = mod(a * a);
            b /= 2;
        }
        return res;
    };

    ll res{1};
    for (const int prime : primes)
    {
        if (prime * prime > A)
        {
            break;
        }

        int exp{};
        while (A % prime == 0)
        {
            A /= prime;
            ++exp;
        }
        if (exp)
        {
            exp *= B;
            if (mod(prime - 1))
            {
                res = mod(res * mod((fast_exp(prime, exp + 1) - 1) * mod_inverse[mod(prime - 1)]));
            }
            else
            {
                res = mod(res * (exp + 1));
            }
        }
    }
    if (A > 1)
    {
        if (mod(A - 1))
        {
            res = mod(res * mod((fast_exp(A, B + 1) - 1) * mod_inverse[(B - 1) % M]));
        }
        else
        {
            res = mod(res * (B + 1));
        }
    }

    cout << res;
}