FenwickTree

Reference: Algorithms for Competitive Programming: Fenwick Tree

Prefix Sum

Python

class FenwickTree:
    """Fenwick tree (binary indexed tree) over ``n`` zero-indexed slots.

    Supports point updates (``add``) and prefix / inclusive-range sums
    (``sum``).  Every slot starts at zero.
    """

    def __init__(self, n: int) -> None:
        """Create a tree with ``n`` slots, all initialised to zero."""
        self.n = n
        # bit[i] holds the total of slots (i & (i + 1)) .. i inclusive.
        self.bit = [0] * n

    def _prefix(self, end):
        """Return the total of slots ``0..end``; ``0`` when ``end`` is negative."""
        total = 0
        while end >= 0:
            total += self.bit[end]
            # Jump to the slot just before the range covered by bit[end].
            end = (end & (end + 1)) - 1
        return total

    def sum(self, r, l=None):
        """Return a prefix sum or an inclusive range sum.

        With one argument, ``sum(r)`` is the total of slots ``0..r``
        (``0`` when ``r`` is negative).

        With two arguments the positional order is ``sum(left, right)``:
        despite the parameter names, ``r`` is then the LEFT bound and ``l``
        the RIGHT bound, and the result is the total of slots
        ``left..right`` inclusive.

        Raises:
            IndexError: if a queried prefix ends at an index ``>= n``.
        """
        if l is None:
            return self._prefix(r)
        # Range sum = prefix up to the right bound minus prefix before the left.
        return self._prefix(l) - self._prefix(r - 1)

    def add(self, idx, delta):
        """Add ``delta`` to slot ``idx``.

        An ``idx >= n`` lies outside the tree and leaves it unchanged.

        Raises:
            IndexError: if ``idx`` is negative.
        """
        if idx < 0:
            # For a negative idx, `idx | (idx + 1)` converges to -1, which is
            # always < n: the loop below would never terminate and would keep
            # writing to bit[-1].  Fail fast before touching any state.
            raise IndexError("FenwickTree index must be non-negative")
        while idx < self.n:
            self.bit[idx] += delta
            # Move to the next slot whose covered range includes idx.
            idx |= idx + 1

# Smoke test for FenwickTree on 10 slots.  The two-argument form of `sum`
# is called as sum(left, right) and covers both bounds.
f = FenwickTree(10)
assert 0 == f.sum(0, 9)  # a fresh tree sums to zero

# A +1 at the last slot shows up only in prefixes that reach slot 9.
f.add(9, 1)
assert (f.sum(9), f.sum(8)) == (1, 0)
assert 1 == f.sum(0, 9)

# A -1 at slot 8 cancels the total but is visible on its own in [0, 8].
f.add(8, -1)
assert (f.sum(0, 9), f.sum(0, 8)) == (0, -1)

C++

// Fenwick tree (binary indexed tree) over n zero-indexed ints, supporting
// point updates and prefix / inclusive-range sum queries.
struct FenwickTree {
    vector<int> bit;  // bit[i] holds the sum of elements (i & (i + 1)) .. i
    int n;            // number of elements

    // Creates a tree of n elements, all zero.
    FenwickTree(int n) {
        this->n = n;
        bit.assign(n, 0);
    }

    // Builds a tree holding the values of `a` by adding them one at a time.
    // The input is taken by const reference so it is not copied.
    FenwickTree(const vector<int>& a) : FenwickTree(static_cast<int>(a.size())) {
        for (size_t i = 0; i < a.size(); i++)
            add(static_cast<int>(i), a[i]);
    }

    // Returns the sum of elements 0..r; 0 when r is negative.
    // r must be less than n.
    int sum(int r) {
        int ret = 0;
        // Each step jumps to the index just before the range covered by bit[r].
        for (; r >= 0; r = (r & (r + 1)) - 1)
            ret += bit[r];
        return ret;
    }

    // Returns the sum of elements l..r inclusive.
    int sum(int l, int r) {
        return sum(r) - sum(l - 1);
    }

    // Adds delta to element idx. An index outside [0, n) is ignored.
    void add(int idx, int delta) {
        // A negative idx would write out of bounds and, because
        // idx | (idx + 1) converges to -1 (always < n), would loop forever.
        if (idx < 0)
            return;
        for (; idx < n; idx = idx | (idx + 1))
            bit[idx] += delta;
    }
};

Problem: Count number of interval intersections

Min/Max

class FenwickTreeMin:
    """Fenwick tree answering prefix-minimum queries over ``n`` slots.

    ``update`` can only lower what is stored (it keeps the minimum of the
    old and new value), so a slot's recorded value never increases.
    """

    def __init__(self, n, inf=10**9):
        """Create a tree with ``n`` zero-indexed slots.

        Args:
            n: number of slots.
            inf: sentinel stored in untouched slots and returned by
                ``getmin`` when nothing has been recorded in the queried
                prefix.  It must be at least as large as every value that
                will be stored, otherwise larger values are masked by it;
                pass e.g. ``float("inf")`` for unbounded data.
        """
        self.n = n
        self.inf = inf
        self.bit = [inf] * n

    def getmin(self, r):
        """Return the smallest value recorded at positions ``0..r``.

        Returns ``self.inf`` when ``r`` is negative or nothing was recorded
        in that prefix.

        Raises:
            IndexError: if ``r >= n``.
        """
        best = self.inf
        while r >= 0:
            best = min(best, self.bit[r])
            # Jump to the slot just before the range covered by bit[r].
            r = (r & (r + 1)) - 1
        return best

    def update(self, idx, val):
        """Record ``val`` at position ``idx``, keeping the smaller value.

        An ``idx >= n`` lies outside the tree and leaves it unchanged.

        Raises:
            IndexError: if ``idx`` is negative.
        """
        if idx < 0:
            # For a negative idx, `idx | (idx + 1)` converges to -1, which is
            # always < n, so the loop below would never terminate.
            raise IndexError("FenwickTreeMin index must be non-negative")
        while idx < self.n:
            if val < self.bit[idx]:
                self.bit[idx] = val
            # Move to the next slot whose covered range includes idx.
            idx |= idx + 1


class FenwickTreeMax:
    """Fenwick tree answering prefix-maximum queries over ``n`` slots.

    ``update`` can only raise what is stored (it keeps the maximum of the
    old and new value), so a slot's recorded value never decreases.
    """

    def __init__(self, n, inf=-(10**9)):
        """Create a tree with ``n`` zero-indexed slots.

        Args:
            n: number of slots.
            inf: sentinel stored in untouched slots and returned by
                ``getmax`` when nothing has been recorded in the queried
                prefix.  It must be at most as large as every value that
                will be stored, otherwise smaller values are masked by it;
                pass e.g. ``float("-inf")`` for unbounded data.
        """
        self.n = n
        self.inf = inf
        self.bit = [inf] * n

    def getmax(self, r):
        """Return the largest value recorded at positions ``0..r``.

        Returns ``self.inf`` when ``r`` is negative or nothing was recorded
        in that prefix.

        Raises:
            IndexError: if ``r >= n``.
        """
        best = self.inf
        while r >= 0:
            best = max(best, self.bit[r])
            # Jump to the slot just before the range covered by bit[r].
            r = (r & (r + 1)) - 1
        return best

    def update(self, idx, val):
        """Record ``val`` at position ``idx``, keeping the larger value.

        An ``idx >= n`` lies outside the tree and leaves it unchanged.

        Raises:
            IndexError: if ``idx`` is negative.
        """
        if idx < 0:
            # For a negative idx, `idx | (idx + 1)` converges to -1, which is
            # always < n, so the loop below would never terminate.
            raise IndexError("FenwickTreeMax index must be non-negative")
        while idx < self.n:
            if val > self.bit[idx]:
                self.bit[idx] = val
            # Move to the next slot whose covered range includes idx.
            idx |= idx + 1

Cited by 3