Leetcode: 1825. Finding MK Average
Problem Statement
from sortedcontainers import SortedList
from collections import deque
class FenwickTree:
    """Zero-indexed binary indexed tree over positions ``0 .. n - 1``.

    Supports adding a delta at one position and querying inclusive prefix
    or range totals.
    """

    def __init__(self, n):
        """Create a tree with ``n`` positions, all starting at zero."""
        self.n = n
        self.bit = [0] * n

    def sum(self, r, l=None):
        """Return an inclusive prefix or range total.

        With one argument, return the total of positions ``0 .. r``.
        With two arguments, the FIRST is the lower bound and the SECOND the
        upper bound: return the total of positions ``r .. l``.
        """
        if l is None:
            total = 0
            pos = r
            while pos >= 0:
                total += self.bit[pos]
                # Jump just below the span covered by the current cell.
                pos = (pos & (pos + 1)) - 1
            return total
        return self.sum(l) - self.sum(r - 1)

    def add(self, idx, delta):
        """Add ``delta`` at position ``idx``; positions >= ``n`` are ignored."""
        pos = idx
        while pos < self.n:
            self.bit[pos] += delta
            # Move to the next cell whose span contains ``idx``.
            pos |= pos + 1
class MKAverage:
    """MK average over the last ``m`` elements of a stream.

    The window is tracked three ways: in arrival order (``nums``), in sorted
    order (``sorted``) and as a value-indexed Fenwick tree of sums (``ac``).
    The MK average is the floor of the mean of the window after dropping its
    ``k`` smallest and ``k`` largest elements.
    """

    def __init__(self, m: int, k: int):
        """Set up an empty window of capacity ``m`` that trims ``k`` per side."""
        self.m = m
        self.k = k
        # Number of elements that remain after trimming both ends.
        self.cnt = self.m - 2 * self.k
        # Indexed by value: cell v accumulates v once per copy of v in the
        # window.  A size of 100_001 covers values 0 .. 100_000.
        self.ac = FenwickTree(100_001)
        self.nums = deque(maxlen=m + 1)
        self.sorted = SortedList()

    def addElement(self, num: int) -> None:
        """Append ``num`` to the stream, evicting the oldest element if full."""
        if len(self.nums) == self.m:
            r = self.nums.popleft()
            self.sorted.remove(r)
            self.ac.add(r, -r)
        self.nums.append(num)
        self.sorted.add(num)
        self.ac.add(num, num)

    def calculateMKAverage(self) -> int:
        """Return the MK average, or ``-1`` while fewer than ``m`` elements exist."""
        if len(self.nums) < self.m:
            return -1
        # The kept elements occupy sorted positions i .. j (inclusive).
        i = self.k
        j = self.m - self.k - 1
        start_num = self.sorted[i]
        end_num = self.sorted[j]
        if start_num == end_num:
            # Every kept element has the same value.
            return start_num
        # FenwickTree.sum(lo, hi) is inclusive on both ends, so this is the
        # total of every window element whose value lies in
        # [start_num, end_num].
        s = self.ac.sum(start_num, end_num)
        # Drop copies of end_num that sit to the right of position j.
        s -= end_num * (self.sorted.bisect_right(end_num) - j - 1)
        # Drop copies of start_num that sit to the left of position i.
        s -= start_num * (self.k - self.sorted.bisect_left(start_num))
        return s // self.cnt
# Smoke test: window size m=3, trimming k=1 element from each end.
obj = MKAverage(3, 1)
obj.addElement(3)
obj.addElement(1)
# Only two elements so far, fewer than m, so the result is -1.
assert obj.calculateMKAverage() == -1
obj.addElement(10)
# Window is [3, 1, 10]; dropping the min (1) and max (10) leaves 3.
assert obj.calculateMKAverage() == 3
obj.addElement(5)
obj.addElement(5)
obj.addElement(5)
# Window is now [5, 5, 5]; the single kept element is 5.
assert obj.calculateMKAverage() == 5
class FenwickTree:
    """Binary indexed tree with zero-based positions ``0 .. n - 1``.

    Offers point updates (``add``) and inclusive prefix/range totals
    (``sum``).
    """

    def __init__(self, n):
        """Allocate ``n`` zero-initialised positions."""
        self.n = n
        self.bit = [0] * n

    def _prefix(self, last):
        """Return the total of positions ``0 .. last`` (0 when ``last < 0``)."""
        acc = 0
        while last >= 0:
            acc += self.bit[last]
            # Skip past the span already folded into this cell.
            last = (last & (last + 1)) - 1
        return acc

    def sum(self, r, l=None):
        """Return the total of ``0 .. r``, or of ``r .. l`` when ``l`` is given.

        Note the argument order in the two-argument form: the first argument
        is the lower bound and the second is the upper bound, both inclusive.
        """
        if l is None:
            return self._prefix(r)
        return self._prefix(l) - self._prefix(r - 1)

    def add(self, idx, delta):
        """Add ``delta`` at position ``idx``; positions >= ``n`` are ignored."""
        cursor = idx
        while cursor < self.n:
            self.bit[cursor] += delta
            cursor = cursor | (cursor + 1)
class MKAverage:
    """MK average over the last ``m`` stream elements using two Fenwick trees.

    Both trees are indexed by element value: ``c`` counts how many copies of
    each value are in the window and ``n`` accumulates their sum.
    """

    # Largest element value the value-indexed trees and the search support.
    _MAX_NUM = 10**5

    def __init__(self, m: int, k: int):
        """Set up an empty window of size ``m`` that trims ``k`` per side."""
        self.m = m
        self.k = k
        # Number of elements kept after trimming both ends.
        self.r = m - 2 * k
        self.c = FenwickTree(self._MAX_NUM + 1)
        self.n = FenwickTree(self._MAX_NUM + 1)
        self.stream = deque()

    def addElement(self, num: int) -> None:
        """Append ``num`` to the stream, evicting the oldest element if full."""
        if len(self.stream) == self.m:
            x = self.stream.popleft()
            self.c.add(x, -1)
            self.n.add(x, -x)
        self.stream.append(num)
        self.c.add(num, +1)
        self.n.add(num, +num)

    def index(self, v):
        """Return the smallest value ``x`` in ``1 .. 10**5`` such that at
        least ``v`` window elements are ``<= x``.

        If no such value exists the upper bound ``10**5`` is returned.
        """
        s = 1
        e = self._MAX_NUM
        while s < e:
            mid = s + (e - s) // 2
            if self.c.sum(mid) >= v:
                e = mid
            else:
                s = mid + 1
        return s

    def calculateMKAverage(self) -> int:
        """Return the MK average, or ``-1`` while fewer than ``m`` elements exist."""
        if len(self.stream) < self.m:
            return -1
        # Values of the k-th and (m - k)-th smallest elements (1-based).
        s = self.index(self.k)
        e = self.index(self.m - self.k)
        # Sum of all elements with value in (s, e].
        ans = self.n.sum(e) - self.n.sum(s)
        # Add back the copies of s that rank above the k smallest.
        ans += s * max(0, self.c.sum(s) - self.k)
        # Remove the copies of e that rank among the k largest.
        ans -= e * max(0, self.c.sum(e) - (self.m - self.k))
        # Integer floor division: exact, and needs no math import.
        return ans // self.r
- We can replace the sorted list and the Fenwick tree with a single segment tree that keeps a sum and a count in each node. Time complexity is \(O(\log M)\) for all operations and space complexity is \(O(M)\), where \(M\) is the number of distinct values allowed.
from collections import deque
class SegTree:
    """Lazily built segment tree over the integer range ``[start, end]``.

    Every node records how many stored values fall inside its range
    (``count``) and their total (``sum``).  Child nodes are created only when
    an update first passes through a node.
    """

    def __init__(self, start, end):
        """Create an empty node covering ``start .. end`` inclusive."""
        self.start, self.end = start, end
        self.sum = self.count = 0
        self.left = self.right = None

    def add(self, v):
        """Store one occurrence of ``v``; out-of-range values are ignored."""
        self._shift(v, 1)

    def remove(self, v):
        """Drop one occurrence of ``v``; out-of-range values are ignored."""
        self._shift(v, -1)

    def query(self, count):
        """Return the total of the ``count`` smallest stored values."""
        total = 0
        node = self
        remaining = count
        while True:
            if remaining == 0:
                return total
            if node.start == node.end:
                # Leaf: every stored value here equals ``node.start``.
                return total + node.start * remaining
            lower = node.left
            if lower.count < remaining:
                # Take the whole lower half and continue in the upper half.
                total += lower.sum
                remaining -= lower.count
                node = node.right
            else:
                node = lower

    def _shift(self, v, sign):
        """Apply ``sign`` occurrences of ``v`` along the root-to-leaf path."""
        if v < self.start or v > self.end:
            return
        node = self
        while True:
            node.sum += sign * v
            node.count += sign
            if node.start == node.end:
                return
            node._extend()
            # Exactly one child covers ``v``; descend into it.
            node = node.left if v <= node.left.end else node.right

    def _extend(self):
        """Create both children if this node spans more than one value."""
        if self.left is not None or self.start >= self.end:
            return
        middle = self.start + (self.end - self.start) // 2
        self.left = SegTree(self.start, middle)
        self.right = SegTree(middle + 1, self.end)

    def print(self, level=0):
        """Dump this subtree to stdout, indenting one space per depth level."""
        indent = " " * level
        print(
            indent,
            "SegTree",
            (self.start, self.end),
            "sum",
            self.sum,
            "count",
            self.count,
        )
        if not self.left:
            return
        for child in (self.left, self.right):
            child.print(level + 1)
class MKAverage:
    """MK average over the last ``m`` stream elements, backed by a segment tree.

    The tree answers "sum of the c smallest window elements", so the trimmed
    sum is the difference of two such queries.
    """

    def __init__(self, m: int, k: int):
        """Set up an empty window of size ``m`` that trims ``k`` per side."""
        self.m = m
        self.k = k
        self.sg = SegTree(0, 100_001)
        self.nums = deque(maxlen=m + 1)

    def addElement(self, num: int) -> None:
        """Append ``num`` to the stream, evicting the oldest element if full."""
        window = self.nums
        if len(window) == self.m:
            oldest = window.popleft()
            self.sg.remove(oldest)
        self.sg.add(num)
        window.append(num)

    def calculateMKAverage(self) -> int:
        """Return the MK average, or ``-1`` while fewer than ``m`` elements exist."""
        if len(self.nums) < self.m:
            return -1
        # Sum of everything except the k largest ...
        without_top = self.sg.query(self.m - self.k)
        # ... minus the k smallest leaves the trimmed middle.
        bottom = self.sg.query(self.k)
        kept = self.m - 2 * self.k
        return (without_top - bottom) // kept
# Smoke test: window size m=3, trimming k=1 element from each end.
obj = MKAverage(3, 1)
obj.addElement(3)
obj.addElement(1)
# Only two elements so far, fewer than m, so the result is -1.
assert obj.calculateMKAverage() == -1
obj.addElement(10)
# Window is [3, 1, 10]; dropping the min (1) and max (10) leaves 3.
assert obj.calculateMKAverage() == 3
obj.addElement(5)
obj.addElement(5)
obj.addElement(5)
# Window is now [5, 5, 5]; the single kept element is 5.
assert obj.calculateMKAverage() == 5
- Can we simplify the solution while still solving the same problem? Keep three sorted lists \(left\), \(mid\) and \(right\) that together hold the last \(m\) elements of the stream. After receiving a number, first evict the oldest element if the window is already full, and then rebalance the lists from left to right so that each one respects its size limit. Time complexity is \(O(n \log n)\) and space complexity is \(O(n)\).
from sortedcontainers import SortedList
class MKAverage:
    """MK average over the last ``m`` stream elements using three sorted lists.

    ``left`` holds the ``k`` smallest window elements, ``right`` the ``k``
    largest and ``mid`` everything in between.  ``sum`` is kept equal to the
    total of ``mid`` so the average can be read off directly.
    """

    def __init__(self, m: int, k: int):
        """Set up an empty window of size ``m`` that trims ``k`` per side."""
        self.m = m
        self.k = k
        # Capacity of ``mid``: the number of elements kept after trimming.
        self.r = m - 2 * k
        self.left = SortedList()
        self.right = SortedList()
        self.mid = SortedList()
        self.sum = 0
        self.nums = deque()

    def addElement(self, num: int) -> None:
        """Append ``num`` to the stream, evicting the oldest element if full."""
        if len(self.nums) == self.m:
            self.popleft()
        self.nums.append(num)
        # Insert at the low end, then push overflow upwards list by list.
        self.left.add(num)
        if len(self.left) > self.k:
            x = self.left.pop(-1)
            self.mid.add(x)
            self.sum += x
        if len(self.mid) > self.r:
            x = self.mid.pop(-1)
            self.right.add(x)
            self.sum -= x

    def popleft(self) -> None:
        """Evict the oldest stream element and refill the lists from the right."""
        num = self.nums.popleft()
        if num in self.left:
            self.left.remove(num)
        elif num in self.mid:
            self.mid.remove(num)
            self.sum -= num
        else:
            self.right.remove(num)
        # Pull the smallest element of the next list down to fill any gap.
        if len(self.left) < self.k:
            x = self.mid.pop(0)
            self.left.add(x)
            self.sum -= x
        if len(self.mid) < self.r:
            x = self.right.pop(0)
            self.mid.add(x)
            self.sum += x

    def calculateMKAverage(self) -> int:
        """Return the MK average, or ``-1`` while fewer than ``m`` elements exist."""
        if len(self.nums) < self.m:
            return -1
        # Integer floor division: exact, and needs no math import.
        return self.sum // len(self.mid)