Segment Tree
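A segment tree keeps, for every node, the combined value (under an associative operation `op`) of one contiguous range of positions, so a point update and a range query each visit only O(log n) nodes. The implementation below stores the nodes in a flat list: the root is node 1, node `i` has children `2 * i` and `2 * i + 1`, and the root covers the inclusive positions `0..n`. Positions that were never updated hold `None` and are skipped when values are combined.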
class SegTree:
    """Array-backed segment tree over the inclusive positions ``0..n``.

    Nodes live in a flat list: the root is node 1 and node ``i`` has the
    children ``2 * i`` and ``2 * i + 1``.  Positions that were never
    updated hold ``None`` and are ignored when values are combined.
    """

    def __init__(self, n, op):
        """Create an empty tree.

        Args:
            n: Largest valid position; the tree covers ``0..n`` inclusive.
            op: Binary function used to combine two stored values
                (for example addition or ``min``).

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        self.n = n
        self.op = op
        # The tree has n + 1 leaves, so size the node array for n + 1
        # (this also gives a usable root slot when n == 0).
        self.tree = [None] * (4 * (n + 1))

    def update(self, pos, val):
        """Store ``val`` at ``pos`` and return the aggregate of the whole tree.

        Raises:
            IndexError: If ``pos`` is outside ``0..n``.  Without this check
                the descent would silently overwrite leaf 0 or leaf n.
        """
        if not 0 <= pos <= self.n:
            raise IndexError(f"position {pos} is outside 0..{self.n}")

        def rec(node, lo, hi):
            if lo == hi:
                self.tree[node] = val
                return
            mid = (lo + hi) // 2
            if pos <= mid:
                rec(node * 2, lo, mid)
            else:
                rec(node * 2 + 1, mid + 1, hi)
            # Recompute this node from its (possibly still empty) children.
            self.tree[node] = self._op(
                self.tree[node * 2], self.tree[node * 2 + 1]
            )

        rec(1, 0, self.n)
        return self.tree[1]

    def query(self, ql, qr):
        """Return the aggregate of positions ``ql..qr`` (inclusive).

        The range is clipped to ``0..n``.  Returns ``None`` when the range
        contains no updated position (including an inverted range where
        ``ql > qr``).
        """

        def rec(node, lo, hi):
            # Disjoint: the node contributes nothing.  Testing for
            # disjointness (rather than "an endpoint of the query falls in
            # this node") also handles nodes strictly inside the query.
            if qr < lo or hi < ql:
                return None
            if ql <= lo and hi <= qr:
                return self.tree[node]
            mid = (lo + hi) // 2
            return self._op(
                rec(node * 2, lo, mid), rec(node * 2 + 1, mid + 1, hi)
            )

        return rec(1, 0, self.n)

    def _op(self, a, b):
        """Combine ``a`` and ``b`` with ``op``, treating ``None`` as "no value"."""
        if a is None:
            return b
        if b is None:
            return a
        return self.op(a, b)


st1 = SegTree(3, lambda a, b: a + b)
assert st1.update(3, 1) == 1
assert st1.query(0, 3) == 1
assert st1.update(1, 2) == 3
assert st1.query(0, 2) == 2
assert st1.update(0, 1) == 4
assert st1.query(0, 3) == 4
assert st1.query(0, 0) == 1

st2 = SegTree(3, min)
assert st2.update(3, 2) == 2
assert st2.query(3, 3) == 2
assert st2.update(0, 3) == 2
assert st2.query(0, 2) == 3
assert st2.update(1, 1) == 1
assert st2.query(0, 2) == 1

# Regression: a node lying strictly inside the query range must be counted.
st3 = SegTree(4, lambda a, b: a + b)
for p in range(5):
    st3.update(p, 1)
assert st3.query(1, 3) == 3
Lazy Segment Tree
Lazy Max Segment Tree
class SegTree:
    """Sparse max segment tree over the inclusive range ``[start, end]``.

    Each node covers ``[self.start, self.end]`` and stores the maximum of
    the values set beneath it.  Child nodes are only created when a
    position below them is set; positions that were never set count as
    ``-inf``.
    """

    def __init__(self, start, end):
        """Create a node covering ``[start, end]`` with no children yet."""
        self.start = start
        self.end = end
        self.left = None
        self.right = None
        self.value = float("-inf")

    def set(self, pos, value):
        """Assign ``value`` to ``pos``; positions outside this node are ignored."""
        if pos < self.start or self.end < pos:
            return
        if self.start == self.end:
            self.value = value
            return
        self._extend()
        # Only one child can contain pos, so descend into that one alone.
        child = self.left if pos <= self.left.end else self.right
        child.set(pos, value)
        self.value = max(self.left.value, self.right.value)

    def query(self, start, end):
        """Return the maximum over ``[start, end]``, or ``-inf`` if nothing is set there."""
        if end < self.start or self.end < start:
            return float("-inf")
        if start <= self.start and self.end <= end:
            return self.value
        if self.left is None:
            # Nothing was ever set below this node, so its children would
            # both report -inf.  Answer directly instead of allocating them:
            # a read-only query should not grow the tree.
            return float("-inf")
        return max(self.left.query(start, end), self.right.query(start, end))

    def _extend(self):
        """Create both children if this node spans more than one position."""
        if self.left is None and self.start < self.end:
            m = self.start + (self.end - self.start) // 2
            self.left = SegTree(self.start, m)
            self.right = SegTree(m + 1, self.end)

    def print(self, level=0):
        """Print this subtree, one node per line, indented by depth."""
        print(" " * level, "SegTree", (self.start, self.end), self.value)
        if self.left is not None:
            self.left.print(level + 1)
            self.right.print(level + 1)


# Example usage:
# sg = SegTree(0, 10)
# sg.set(5, 2)
# sg.set(1, 10)
# assert sg.query(1, 3) == 10
# assert sg.query(2, 5) == 2
# sg.print()