Cod sursa(job #2916093)

Utilizator nstefanNeagu Stefan nstefan Data 28 iulie 2022 00:30:54
Problema Arbori indexati binar Scor 100
Compilator cpp-64 Status done
Runda Arhiva educationala Marime 2.1 kb
#include <bits/stdc++.h>
using namespace std;

// Input and output streams for the task; the files are opened when these
// globals are constructed.
ifstream fin("aib.in");
ofstream fout("aib.out");
// n: number of array elements main reads into v[1..n];
// m: number of queries main processes afterwards.
int n, m;
// Initial array values, stored 1-based (v[0] is never read by main).
// The fixed size of 100002 means n must not exceed 100001.
vector<int> v(100002);

// Binary indexed tree (Fenwick tree) over the 1-based positions 1..size().
// Supports adding a value to a single position and querying the sum of a
// prefix; both walk at most one tree node per bit of the position.
template <typename T>
class BIT {
private:
    // m_bit[i] stores the sum of the lowbit(i) elements that end at position i.
    // Slot 0 is unused so that positions can be 1-based.
    std::vector<T> m_bit;
    size_t m_bitSize;

    // Returns the value of the lowest set bit of pos (0 when pos == 0).
    // Unsigned subtraction wraps, so 0 - pos is the two's-complement negation.
    static size_t lowbit(size_t pos)
    {
        return pos & (size_t{0} - pos);
    }

public:
    // Creates a BIT based on copies of the elements from [first,last).
    // The element *first becomes position 1. A reversed or empty range
    // yields an empty tree instead of a bogus, huge size.
    BIT(typename std::vector<T>::iterator first, typename std::vector<T>::iterator last)
        : m_bit(), m_bitSize(0)
    {
        if (first < last)
            m_bitSize = static_cast<size_t>(last - first);
        m_bit.assign(m_bitSize + 1, T{});

        for (size_t index = 1; index <= m_bitSize; ++index, ++first) {
            m_bit[index] = *first;
        }

        // Linear-time build: push every node's partial sum into its parent.
        for (size_t i = 1; i <= m_bitSize; ++i) {
            const size_t parent = i + lowbit(i);
            if (parent <= m_bitSize)
                m_bit[parent] += m_bit[i];
        }
    }

    // Returns the sum of all elements until pos (including pos).
    // pos == 0 gives T{}; a pos beyond size() is clamped to size(), so the
    // call never reads outside the tree.
    T sum(size_t pos) const
    {
        if (pos > m_bitSize)
            pos = m_bitSize;

        // Accumulate in T (not int) so wider element types are not truncated.
        T total{};
        for (; pos > 0; pos -= lowbit(pos)) {
            total += m_bit[pos];
        }
        return total;
    }

    // Adds val to the element on index pos.
    // Positions outside 1..size() are ignored; in particular pos == 0 must be
    // rejected up front because lowbit(0) == 0 would never advance the loop.
    void add(size_t pos, T val)
    {
        if (pos == 0)
            return;
        for (; pos <= m_bitSize; pos += lowbit(pos)) {
            m_bit[pos] += val;
        }
    }

    // Returns the number of elements the tree was built over.
    int size() const
    {
        return static_cast<int>(m_bitSize);
    }
};

int main()
{
    fin >> n >> m;
    for (int i = 1; i <= n; i++) {
        fin >> v[i];
    }
    BIT<int> bit = BIT<int>(v.begin() + 1, v.begin() + n + 1);

    // answer queries
    while (m--) {
        int c;
        fin >> c;
        if (c == 0) {
            int a, b;
            fin >> a >> b;
            bit.add(a, b);
        }
        if (c == 1) {
            int a, b;
            fin >> a >> b;
            fout << bit.sum(b) - bit.sum(a - 1) << '\n';
        }
        if (c == 2) {
            int a;
            fin >> a;
            int power = 1;
            for (; power <= n; power <<= 1)
                ;
            int index = 1;
            for (; power; power >>= 1)
                if (index + power <= n and bit.sum(index + power) <= a)
                    index += power;
            fout << (bit.sum(index) == a ? index : -1) << '\n';
        }
    }
}