Segment Tree

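TODO: Write a full explanation. In short: a segment tree stores, for each node, the aggregate (under a binary operation such as sum or min) of a contiguous range of positions; a point update rewrites the aggregates on one root-to-leaf path, and a range query combines the nodes whose ranges lie inside the queried interval.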

class SegTree:
    """Array-backed segment tree with point update and range query.

    The tree covers the integer positions ``0..n`` inclusive. Positions that
    were never written hold ``None`` and are skipped when aggregating, so
    ``op`` does not need an identity element.
    """

    def __init__(self, n, op):
        """Create an empty tree.

        Args:
            n: Largest valid position; positions run from 0 to n inclusive.
            op: Binary function combining two stored values (for example
                addition or ``min``). It is applied left-to-right, so it is
                expected to be associative.
        """
        self.n = n
        self.op = op
        # Positions 0..n make n + 1 leaves, so size by n + 1. Sizing by n
        # left no slot for the root when n == 0.
        self.tree = [None] * (4 * (n + 1))

    def update(self, pos, val):
        """Store ``val`` at ``pos`` and return the aggregate of the whole tree.

        Raises:
            IndexError: If ``pos`` is outside ``0..n``. Without this check the
                descent would silently overwrite the first or last leaf.
        """
        if not 0 <= pos <= self.n:
            raise IndexError(f"position {pos} out of range 0..{self.n}")

        def rec(i, lo, hi):
            if lo == hi:
                self.tree[i] = val
                return

            mid = (lo + hi) // 2
            if pos <= mid:
                rec(i * 2, lo, mid)
            else:
                rec(i * 2 + 1, mid + 1, hi)

            # Recompute this node from its children on the way back up.
            self.tree[i] = self._op(self.tree[i * 2], self.tree[i * 2 + 1])

        rec(1, 0, self.n)
        return self.tree[1]

    def query(self, ql, qr):
        """Return ``op`` folded over the values stored in ``ql..qr`` inclusive.

        Returns ``None`` when no position in the range has been written (or
        the range does not overlap ``0..n``).
        """
        def rec(i, lo, hi):
            # Two closed intervals are disjoint exactly when one ends before
            # the other starts. Testing only whether a query endpoint falls
            # inside the node misses nodes lying strictly inside the query.
            if lo > hi or qr < lo or hi < ql:
                return None
            if ql <= lo and hi <= qr:
                return self.tree[i]
            mid = (lo + hi) // 2
            return self._op(rec(i * 2, lo, mid), rec(i * 2 + 1, mid + 1, hi))

        return rec(1, 0, self.n)

    def _op(self, a, b):
        """Combine two node values, treating ``None`` as "no value"."""
        if a is None:
            return b
        if b is None:
            return a
        return self.op(a, b)


# Sum tree: each step calls update() or query() and checks the returned value.
st1 = SegTree(3, lambda a, b: a + b)
for call, args, expected in (
    (st1.update, (3, 1), 1),
    (st1.query, (0, 3), 1),
    (st1.update, (1, 2), 3),
    (st1.query, (0, 2), 2),
    (st1.update, (0, 1), 4),
    (st1.query, (0, 3), 4),
    (st1.query, (0, 0), 1),
):
    assert call(*args) == expected


# Min tree: each step calls update() or query() and checks the returned value.
st2 = SegTree(3, min)
for call, args, expected in (
    (st2.update, (3, 2), 2),
    (st2.query, (3, 3), 2),
    (st2.update, (0, 3), 2),
    (st2.query, (0, 2), 3),
    (st2.update, (1, 1), 1),
    (st2.query, (0, 2), 1),
):
    assert call(*args) == expected

Lazy Segment Tree

Lazy Max Segment Tree

class SegTree:
    """Max segment tree over ``start..end`` whose nodes are created on demand.

    Each node covers an inclusive range and stores the maximum value set
    inside it; ranges with nothing set report ``float("-inf")``. Child nodes
    are only allocated when a ``set`` descends through their parent.
    """

    def __init__(self, start, end):
        """Create a node covering ``start..end`` inclusive, with no children."""
        self.start = start
        self.end = end
        self.left = None
        self.right = None
        self.value = float("-inf")

    def set(self, pos, value):
        """Set the value at ``pos``; positions outside the node are ignored."""
        if pos < self.start or self.end < pos:
            return
        if self.start == self.end:
            self.value = value
            return
        self._extend()
        # Only one child can contain pos, so descend into that one alone.
        child = self.left if pos <= self.left.end else self.right
        child.set(pos, value)
        self.value = max(self.left.value, self.right.value)

    def query(self, start, end):
        """Return the maximum over ``start..end`` inclusive (``-inf`` if empty).

        Queries never allocate nodes: a node without children has had nothing
        set beneath it, so its own value already answers the query.
        """
        if end < self.start or self.end < start:
            return float("-inf")
        if start <= self.start and self.end <= end:
            return self.value
        if self.left is None:
            # Unexpanded internal node: still holds its initial -inf.
            return self.value
        return max(self.left.query(start, end), self.right.query(start, end))

    def _extend(self):
        """Create both children of a non-leaf node if they do not exist yet."""
        if self.left is None and self.start < self.end:
            m = self.start + (self.end - self.start) // 2
            self.left = SegTree(self.start, m)
            self.right = SegTree(m + 1, self.end)

    def print(self, level=0):
        """Print the allocated nodes, indented by depth."""
        print(" " * level, "SegTree", (self.start, self.end), self.value)
        if self.left is not None:
            self.left.print(level + 1)
            self.right.print(level + 1)

# sg = SegTree(0, 10)
# sg.set(5, 2)
# sg.set(1, 10)
# assert sg.query(1, 3) == 10
# assert sg.query(2, 5) == 2
# sg.print()

Cited by 3