LeetCode 2276: Count Integers in Intervals
Problem Statement
class SegTree:
    """Dynamically allocated segment tree that counts covered integers.

    A node represents the closed integer range ``[start, end]``; ``value`` is
    the number of integers in that range covered so far. Child nodes are only
    created when a node actually has to be split, and the first partial update
    on an untouched node is parked in ``pending`` instead of being pushed down.
    """

    # Many nodes may be allocated; slots avoid a per-node ``__dict__``.
    __slots__ = ("start", "end", "left", "right", "value", "pending")

    def __init__(self, start, end):
        """Create an uncovered node for the closed range ``[start, end]``."""
        self.start = start
        self.end = end
        self.left = None      # child covering [start, mid], created on demand
        self.right = None     # child covering [mid + 1, end], created on demand
        self.value = 0        # count of covered integers inside this node
        self.pending = None   # (start, end) of a parked update, or None

    def update(self, start, end, force=False):
        """Mark every integer in the closed range ``[start, end]`` as covered.

        Args:
            start: Inclusive lower bound of the interval to cover.
            end: Inclusive upper bound of the interval to cover.
            force: When true, skip the "park it in ``pending``" shortcut and
                push the interval into the children. ``_extend`` uses this
                when it replays a parked interval.
        """
        if start > end:
            # An inverted range covers nothing; without this guard the lazy
            # branch below would record a negative count.
            return
        if self.start > end or self.end < start:
            return  # no overlap with this node
        size = self.end - self.start + 1
        if self.value == size:
            return  # already fully covered
        if start <= self.start and self.end <= end:
            # Fully covered: the subtree and any parked interval are obsolete.
            self.left = None
            self.right = None
            self.pending = None
            self.value = size
            return
        if not force and self.pending is None and self.value == 0:
            # Untouched node: remember the interval instead of splitting now.
            self.pending = (start, end)
            self.value = min(self.end, end) - max(self.start, start) + 1
            return
        self._extend()
        self.left.update(start, end)
        self.right.update(start, end)
        self.value = self.left.value + self.right.value

    def _extend(self):
        """Replay a parked interval (if any) and make sure children exist."""
        if self.pending is not None:
            start, end = self.pending
            self.pending = None
            # Reset the count; the forced update recomputes it from the children.
            self.value = 0
            self.update(start, end, True)
        if self.left is None:
            m = self.start + (self.end - self.start) // 2
            self.left = SegTree(self.start, m)
            self.right = SegTree(m + 1, self.end)

    def print(self, level=0):
        """Print this node and its subtree, one node per line, indented by depth."""
        print(" " * level, (self.start, self.end), self.value, self.pending)
        if self.left:
            self.left.print(level + 1)
            self.right.print(level + 1)
class CountIntervals:
    """Tracks how many distinct integers are covered by the added intervals.

    Backed by a ``SegTree`` whose root spans ``[0, 10**9]``.
    """

    def __init__(self):
        """Start with no interval added."""
        self.st = SegTree(0, 10**9)

    def add(self, left: int, right: int) -> None:
        """Add the closed interval ``[left, right]``.

        An inverted interval (``left > right``) contains no integers and is
        ignored, so it cannot corrupt the running count.
        """
        if left > right:
            return
        self.st.update(left, right)

    def count(self) -> int:
        """Return the number of integers covered by at least one interval."""
        return self.st.value
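- Alternatively, we can keep a sorted list of disjoint intervals and merge on insert. Each `add` binary-searches the list, and every stored interval is merged away at most once, so the search and merge work is \(O(n \log n)\) in total, where \(n\) is the number of queries (the list splice itself may still shift up to \(O(n)\) elements per call).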
class CountIntervals:
def __init__(self):
self.segs = []
self.value = 0
def add(self, left: int, right: int) -> None:
i = bisect_left(self.segs, left, key=lambda e: e[0])
j = i = i - (1 if i > 0 and self.segs[i - 1][1] >= left else 0)
while j < len(self.segs) and self.segs[j][0] <= right:
self.value -= self.segs[j][1] - self.segs[j][0] + 1
left = min(left, self.segs[j][0])
right = max(right, self.segs[j][1])
j += 1
self.value += right - left + 1
self.segs[i:j] = [(left, right)]
def count(self) -> int:
return self.value