LeetCode 1825: Finding MK Average

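Problem statement: given integers m and k and a stream of integers, the MK average is defined once at least m values have arrived. It is the mean of the most recent m values after removing the k smallest and the k largest, rounded down; before that it is -1. Several alternative solutions follow.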

from sortedcontainers import SortedList
from collections import deque


class FenwickTree:
    """Fenwick (binary indexed) tree over positions ``0 .. n - 1``.

    Uses the 0-based layout in which ``bit[i]`` covers positions
    ``(i & (i + 1)) .. i``, giving O(log n) point updates and prefix sums.
    """

    def __init__(self, n):
        self.n = n
        self.bit = [0] * n

    def sum(self, r, l=None):
        """Inclusive prefix sum or inclusive range sum.

        ``sum(r)`` returns the total of positions ``0 .. r`` (0 when r < 0).
        ``sum(r, l)`` returns the total of positions ``r .. l`` inclusive.
        Note that the FIRST argument is the lower bound despite its name;
        the names are kept so existing keyword callers keep working.
        """
        if l is None:
            return self._prefix(r)
        return self._prefix(l) - self._prefix(r - 1)

    def _prefix(self, upto):
        # Walk down through the covering intervals ending at `upto`.
        total = 0
        pos = upto
        while pos >= 0:
            total += self.bit[pos]
            pos = (pos & (pos + 1)) - 1
        return total

    def add(self, idx, delta):
        """Add ``delta`` at position ``idx`` (no-op when idx >= n)."""
        pos = idx
        while pos < self.n:
            self.bit[pos] += delta
            pos |= pos + 1


class MKAverage:
    """MK average over a sliding window of the last ``m`` stream values.

    The window is kept in three structures: a deque (arrival order), a
    ``SortedList`` (order statistics), and a Fenwick tree indexed by value
    whose slot ``v`` holds ``v * (occurrences of v)``, so the total of all
    window values in a value range is a single range-sum query.
    """

    def __init__(self, m: int, k: int):
        self.m = m
        self.k = k
        # Number of values that remain after trimming k from each end.
        self.cnt = self.m - 2 * self.k
        # Indexed by value; 100_001 slots cover values 0..100_000.
        self.ac = FenwickTree(100_001)
        self.nums = deque(maxlen=m + 1)
        self.sorted = SortedList()

    def addElement(self, num: int) -> None:
        """Append ``num`` to the window, evicting the oldest value once full."""
        if len(self.nums) == self.m:
            r = self.nums.popleft()
            self.sorted.remove(r)
            self.ac.add(r, -r)
        self.nums.append(num)
        self.sorted.add(num)
        self.ac.add(num, num)

    def calculateMKAverage(self) -> int:
        """Return the floored mean of the window without its k smallest and
        k largest values, or -1 while fewer than ``m`` values have arrived."""
        if len(self.nums) < self.m:
            return -1

        # The kept middle is sorted[i .. j] inclusive.
        i = self.k
        j = self.m - self.k - 1
        start_num = self.sorted[i]
        end_num = self.sorted[j]

        if start_num == end_num:
            # Every kept value equals start_num.
            return start_num

        # Total of every window value v with start_num <= v <= end_num.
        # FenwickTree.sum(a, b) is inclusive on BOTH ends, so the upper
        # bound is end_num itself. Passing end_num + 1 would also count
        # values equal to end_num + 1 and overrun the tree at 100_000.
        s = self.ac.sum(start_num, end_num)
        # That value range also contains copies of end_num past index j
        # (they belong to the trimmed top k) ...
        s -= end_num * (self.sorted.bisect_right(end_num) - j - 1)
        # ... and copies of start_num before index k (trimmed bottom k).
        s -= start_num * (self.k - self.sorted.bisect_left(start_num))
        return s // self.cnt


# Smoke test (the LeetCode example) for the MKAverage class defined directly
# above. NOTE: with m=3, k=1 only one value survives trimming, so these
# checks never exercise a middle slice with more than one value.
obj = MKAverage(3, 1)
obj.addElement(3)
obj.addElement(1)
assert obj.calculateMKAverage() == -1  # fewer than m values so far
obj.addElement(10)
assert obj.calculateMKAverage() == 3  # window [3, 1, 10]: drop 1 and 10
obj.addElement(5)
obj.addElement(5)
obj.addElement(5)
assert obj.calculateMKAverage() == 5  # window [5, 5, 5]
class FenwickTree:
    """0-based Fenwick tree with point updates and inclusive prefix sums.

    Positions run from 0 to ``n - 1``; ``bit[i]`` stores the total of
    positions ``(i & (i + 1)) .. i``.
    """

    def __init__(self, n):
        self.n = n
        self.bit = [0] * n

    def sum(self, r, l=None):
        """Return the total of positions ``0..r``, or of ``r..l`` when ``l``
        is supplied (both ends inclusive; ``r`` is the lower bound)."""
        if l is not None:
            upper_total = self.sum(l)
            below_total = self.sum(r - 1)
            return upper_total - below_total
        acc, pos = 0, r
        while pos >= 0:
            acc += self.bit[pos]
            pos = (pos & (pos + 1)) - 1
        return acc

    def add(self, idx, delta):
        """Add ``delta`` to position ``idx`` and every node covering it."""
        pos = idx
        while pos < self.n:
            self.bit[pos] += delta
            pos |= pos + 1


class MKAverage:
    """MK average using two value-indexed Fenwick trees.

    ``self.c`` counts occurrences of each value in the window and ``self.n``
    holds ``value * occurrences``, so order statistics come from a binary
    search over ``self.c`` and range totals from ``self.n``.
    """

    def __init__(self, m: int, k: int):
        self.m = m
        self.k = k
        # Number of values kept after trimming k from each end.
        self.r = m - 2 * k
        # Indexed by value; 10**5 + 1 slots cover values 0..10**5.
        self.c = FenwickTree(10**5 + 1)
        self.n = FenwickTree(10**5 + 1)
        self.stream = deque()

    def addElement(self, num: int) -> None:
        """Append ``num``, evicting the oldest value once ``m`` are held."""
        if len(self.stream) == self.m:
            x = self.stream.popleft()
            self.c.add(x, -1)
            self.n.add(x, -x)
        self.stream.append(num)
        self.c.add(num, +1)
        self.n.add(num, +num)

    def index(self, v):
        """Return the smallest value x in [1, 10**5] such that at least ``v``
        window values are <= x, i.e. the v-th smallest value (1-based).

        Returns 10**5 if no such value exists.
        """
        s = 1
        e = 10**5
        while s < e:
            m = s + (e - s) // 2
            if self.c.sum(m) >= v:
                e = m
            else:
                s = m + 1
        return s

    def calculateMKAverage(self) -> int:
        """Return the floored mean of the window without its k smallest and
        k largest values, or -1 while fewer than ``m`` values have arrived."""
        if len(self.stream) < self.m:
            return -1

        # s is the k-th smallest value and e the (m-k)-th smallest (1-based).
        s = self.index(self.k)
        e = self.index(self.m - self.k)
        # Total of window values strictly greater than s and at most e.
        ans = self.n.sum(e) - self.n.sum(s)
        # Add back the copies of s that sit inside the kept middle.
        ans += s * max(0, self.c.sum(s) - self.k)
        # Remove the copies of e that fall into the trimmed top k.
        ans -= e * max(0, self.c.sum(e) - (self.m - self.k))
        # `//` is exact integer floor division and needs no import.
        return ans // self.r
from collections import deque


class SegTree:
    """Dynamically allocated segment tree over the integer keys [start, end].

    Every node records how many stored values fall in its key range
    (``count``) and their total (``sum``). Child nodes are created lazily
    by ``_extend`` the first time an update reaches a non-leaf node.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.sum = 0
        self.count = 0
        self.left = None
        self.right = None

    def add(self, v):
        """Insert one copy of ``v``; values outside [start, end] are ignored."""
        if v < self.start or self.end < v:
            return
        self.sum += v
        self.count += 1
        if self.start == self.end:
            return
        self._extend()
        # Only the child whose range contains v records it; the other returns.
        self.left.add(v)
        self.right.add(v)

    def remove(self, v):
        """Delete one copy of ``v``; values outside [start, end] are ignored.

        NOTE(review): removing a value that is not stored is not detected and
        leaves negative counts along its path.
        """
        if v < self.start or self.end < v:
            return
        self.sum -= v
        self.count -= 1
        if self.start == self.end:
            return
        self._extend()
        self.left.remove(v)
        self.right.remove(v)

    def query(self, count):
        """Return the sum of the ``count`` smallest stored values.

        Raises:
            ValueError: if ``count`` is negative or larger than the number of
                values currently stored in this node's range.
        """
        # Without this guard, an oversized request either walks into a child
        # that was never allocated (AttributeError on None) or, at a leaf,
        # silently returns start * count for values that are not stored.
        if count < 0 or count > self.count:
            raise ValueError(
                f"count must be between 0 and {self.count}, got {count}"
            )
        if count == 0:
            return 0
        if self.start == self.end:
            # Every value stored at a leaf equals its single key.
            return self.start * count
        if self.left.count < count:
            # Take everything on the left and the remainder from the right.
            return self.left.sum + self.right.query(count - self.left.count)
        return self.left.query(count)

    def _extend(self):
        """Create both children, split at the midpoint, if not yet created."""
        if self.left is None and self.start < self.end:
            m = self.start + (self.end - self.start) // 2
            self.left = SegTree(self.start, m)
            self.right = SegTree(m + 1, self.end)

    def print(self, level=0):
        """Debug helper: print this subtree, one node per line, indented by depth."""
        print(
            " " * level,
            "SegTree",
            (self.start, self.end),
            "sum",
            self.sum,
            "count",
            self.count,
        )
        if self.left:
            self.left.print(level + 1)
            self.right.print(level + 1)


class MKAverage:
    """MK average backed by a dynamic segment tree of the current window.

    ``self.sg`` stores every window value. The total of the kept middle is
    the sum of the (m - k) smallest values minus the sum of the k smallest.
    """

    def __init__(self, m: int, k: int):
        self.m, self.k = m, k
        self.sg = SegTree(0, 100_001)
        self.nums = deque(maxlen=m + 1)

    def addElement(self, num: int) -> None:
        """Record ``num``; once the window holds m values, drop the oldest first."""
        window = self.nums
        if len(window) == self.m:
            oldest = window.popleft()
            self.sg.remove(oldest)
        self.sg.add(num)
        window.append(num)

    def calculateMKAverage(self) -> int:
        """Floored mean of the window minus its k smallest and k largest
        values; -1 while the window is not yet full."""
        if len(self.nums) < self.m:
            return -1
        through_middle = self.sg.query(self.m - self.k)
        bottom_k = self.sg.query(self.k)
        kept = self.m - 2 * self.k
        return (through_middle - bottom_k) // kept


# Smoke test (the LeetCode example) for the SegTree-based MKAverage defined
# directly above. NOTE: with m=3, k=1 only one value survives trimming, so
# these checks never exercise a middle slice with more than one value.
obj = MKAverage(3, 1)
obj.addElement(3)
obj.addElement(1)
assert obj.calculateMKAverage() == -1  # fewer than m values so far
obj.addElement(10)
assert obj.calculateMKAverage() == 3  # window [3, 1, 10]: drop 1 and 10
obj.addElement(5)
obj.addElement(5)
obj.addElement(5)
assert obj.calculateMKAverage() == 5  # window [5, 5, 5]
from sortedcontainers import SortedList

class MKAverage:
    """MK average kept incrementally in three sorted partitions.

    Values only move upward from ``left`` to ``mid`` to ``right``, and the
    value moved is always the largest in its partition. So every value in
    ``left`` is <= every value in ``mid``, which is <= every value in
    ``right``. ``left`` holds at most k values and ``mid`` at most m - 2k;
    ``self.sum`` is the running total of ``mid``.
    """

    def __init__(self, m: int, k: int):
        self.m = m
        self.k = k
        # Size of the kept middle once the window is full.
        self.r = m - 2 * k
        self.left = SortedList()
        self.right = SortedList()
        self.mid = SortedList()
        self.sum = 0
        self.nums = deque()

    def addElement(self, num: int) -> None:
        """Append ``num``, evicting the oldest value first if the window is full."""
        if len(self.nums) == self.m:
            self.popleft()
        self.nums.append(num)
        # Insert at the bottom, then push overflow upward: the largest of
        # `left` moves into `mid`, and the largest of `mid` into `right`.
        self.left.add(num)
        if len(self.left) > self.k:
            x = self.left.pop(-1)
            self.mid.add(x)
            self.sum += x
        if len(self.mid) > self.r:
            x = self.mid.pop(-1)
            self.right.add(x)
            self.sum -= x

    def popleft(self) -> None:
        """Evict the oldest window value and refill the partitions from above.

        Assumes the window is full, which ``addElement`` checks before calling.
        Otherwise the ``pop(0)`` refills below may run on an empty partition.
        """
        num = self.nums.popleft()
        # Removing any equal value is fine: the refill below restores the
        # sorted split of the remaining multiset.
        if num in self.left:
            self.left.remove(num)
        elif num in self.mid:
            self.mid.remove(num)
            self.sum -= num
        else:
            self.right.remove(num)
        # Refill: the smallest of `mid` drops into `left`, and the smallest
        # of `right` drops into `mid`.
        if len(self.left) < self.k:
            x = self.mid.pop(0)
            self.left.add(x)
            self.sum -= x
        if len(self.mid) < self.r:
            x = self.right.pop(0)
            self.mid.add(x)
            self.sum += x

    def calculateMKAverage(self) -> int:
        """Return the floored mean of ``mid``, or -1 while fewer than ``m``
        values have arrived."""
        if len(self.nums) < self.m:
            return -1
        # `//` is exact integer floor division and needs no import.
        return self.sum // len(self.mid)