Leetcode: 402. Remove K Digits

Problem Statement: Given a string representing a number \(n\) and an integer \(k\), return the smallest number after removing \(k\) digits from \(n\).

Suppose that \(k=1\). Which digit should we remove from \(n\) to create the smallest possible number? If \(n=9123\), we have to remove 9. If \(n=1834\), we have to remove 8. If \(n=1892\), we have to remove 9, resulting in 182 (removing 8 instead would give 192). Finally, if \(n=1234\), we have to remove 4. What do all these examples have in common? We removed the digit at the first index \(i\) where \(n[i]>n[i+1]\), or the last digit if no such \(i\) exists. So, we can solve the problem by finding \(i\), removing that digit from the number, and repeating the same process \(k-1\) more times. The time complexity of this solution is \(O(|n| \times k)\), since it spends \(O(|n|)\) time to remove one digit and there are \(k\) digits to remove.

def should_keep(nums, i):
    if i + 1 == len(nums):
        return False
    if nums[i] <= nums[i+1]:
        return True
    return False


def remove_leading_zeros(nums):
    if len(nums) > 1 and nums[0] == "0":
        return remove_leading_zeros(nums[1:])
    return nums


def brute_force(nums, k):
    if k == 0:
        return remove_leading_zeros(nums)
    if len(nums) == 1:
        return "0"
    i = 0
    while should_keep(nums, i):
        i += 1
    return brute_force(nums[0:i] + nums[i+1:], k - 1)


assert brute_force("1432219", 3) == "1219"
assert brute_force("10200", 1) == "200"
assert brute_force("10", 2) == "0"
assert brute_force("1", 1) == "0"
assert brute_force("1", 0) == "1"
assert brute_force("112", 1) == "11"
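
As a sanity check on the greedy rule, one can compare it against an exhaustive search over every way of deleting \(k\) digits. The sketch below is not part of the original solution: normalize, exhaustive and greedy are hypothetical helper names, and greedy is a compact iterative restatement of the rule described above, so that the block runs on its own.

```python
from itertools import combinations, product


def normalize(digits):
    """Strip leading zeros, mapping the empty string to "0"."""
    return digits.lstrip("0") or "0"


def exhaustive(num, k):
    """Try every set of k positions to delete and keep the smallest result."""
    best = None
    for removed in combinations(range(len(num)), k):
        dropped = set(removed)
        kept = "".join(d for i, d in enumerate(num) if i not in dropped)
        value = int(normalize(kept))
        if best is None or value < best:
            best = value
    return str(best)


def greedy(num, k):
    """Remove, k times, the first digit that is larger than its successor
    (or the last digit if there is no such digit)."""
    for _ in range(k):
        i = 0
        while i + 1 < len(num) and num[i] <= num[i + 1]:
            i += 1
        num = num[:i] + num[i + 1:]
    return normalize(num)


# Compare both approaches on every digit string of length 1 to 4
# and every valid k for that length.
for length in range(1, 5):
    for digits in product("0123456789", repeat=length):
        num = "".join(digits)
        for k in range(length + 1):
            assert greedy(num, k) == exhaustive(num, k), (num, k)
```

The exhaustive search is exponential in general, so it is only useful for short inputs like these; its purpose here is to give some confidence in the greedy rule, not to replace it.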

Can we reuse or extend the solution of a sub-problem to solve the next sub-problem more efficiently? brute_force breaks the original problem into smaller ones by removing one digit at a time. So, after removing \(k-1\) digits, it is easy to remove the \(k\)-th one, but brute_force doesn't reuse any of the work done so far to find it quickly. The question can then be rephrased as follows: can we reuse or extend the solution of removing \(k-1\) digits to remove the \(k\)-th digit more efficiently? Let \(i\) be the index of the first digit removed from the input. We know that \(n[0] \leq n[1] \leq \dots \leq n[i]\) and either \(i=|n|-1\) or \(n[i] > n[i+1]\). brute_force always starts from index 0 to find the next digit to remove, which wastes time: the prefix \(n[0..i-1]\) is still in non-decreasing order after the removal, so the next digit to remove cannot be before position \(i-1\). Therefore, we can resume the search at position \(i-1\) (or at 0 when \(i=0\)) to speed up the search for the next digit.

def should_keep(num, i):
    if i + 1 >= len(num):
        return False
    if num[i] <= num[i+1]:
        return True
    return False


def remove_leading_zeros(num):
    if len(num) > 1 and num[0] == "0":
        return remove_leading_zeros(num[1:])
    return num


def solve_slow(num, k, i=0):
    if k == 0:
        return remove_leading_zeros(num)
    if len(num) == 1:
        return "0"
    while should_keep(num, i):
        i += 1
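    # Resume one position back; clamp to 0 so the index never becomes negative
    # (a negative index would wrap around to the end of the string).
    return solve_slow(num[0:i] + num[i+1:], k - 1, max(i - 1, 0))


assert solve_slow("1432219", 3) == "1219"
assert solve_slow("10200", 1) == "200"
assert solve_slow("10", 2) == "0"
assert solve_slow("1", 1) == "0"
assert solve_slow("1", 0) == "1"
assert solve_slow("9991", 3) == "1"
assert solve_slow("112", 1) == "11"
assert solve_slow("219", 2) == "1"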
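
The claim that the search can resume one position before the last removal can also be checked empirically. The sketch below records the index removed at each step of the digit-by-digit process and verifies that each index is at least the previous one minus one. The helper names first_drop and removal_indices are illustrative and do not come from the code above.

```python
def first_drop(num):
    """Index of the first digit larger than its successor, or the last index."""
    i = 0
    while i + 1 < len(num) and num[i] <= num[i + 1]:
        i += 1
    return i


def removal_indices(num, k):
    """Index (within the current string) of the digit removed at each step."""
    indices = []
    for _ in range(k):
        i = first_drop(num)
        indices.append(i)
        num = num[:i] + num[i + 1:]
    return indices


for num, k in [("1432219", 3), ("9991", 3), ("219", 2)]:
    indices = removal_indices(num, k)
    # Every removal happens at or after (previous index - 1).
    assert all(b >= a - 1 for a, b in zip(indices, indices[1:]))
    print(num, k, indices)

# 1432219 3 [1, 1, 2]
# 9991 3 [2, 1, 0]
# 219 2 [0, 1]
```

In the last example the first removal happens at index 0, so the resume position \(i-1\) would be \(-1\). Since Python treats negative indices as offsets from the end of the string, the resume position has to be clamped to 0 in that case.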

solve_slow improves on brute_force, but it is still \(O(|n| \times k)\), because slicing a string in Python takes time proportional to its length. Looking carefully, you might see that each call of solve_slow splits \(n\) into two parts: the left part holds digits in non-decreasing order (a monotonic stack) and the right part holds the digits that we still have to process. If we keep both parts separate, we don't spend time proportional to their combined length joining them at every step.

from collections import deque


def solve(num, k):
    ans = deque(maxlen=len(num))
    for digit in num:
        if k == 0:
            ans.append(digit)
            continue
        while k > 0 and ans and ans[-1] > digit:
            ans.pop()
            k -= 1
        ans.append(digit)

    # ans is in non-decreasing order; if digits remain to be removed, drop them from the end
    while k > 0 and len(ans) > 0:
        ans.pop()
        k -= 1

    # remove leading zeros
    while len(ans) > 0 and ans[0] == "0":
        ans.popleft()

    if len(ans) == 0:
        return "0"

    return "".join(ans)


assert solve("1432219", 3) == "1219"
assert solve("10200", 1) == "200"
assert solve("10", 2) == "0"
assert solve("1", 1) == "0"
assert solve("1", 0) == "1"
assert solve("9991", 3) == "1"
assert solve("112", 1) == "11"
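
To see how the stack evolves, it can help to trace its contents after each digit is processed. The sketch below uses a plain list as the stack and a hypothetical helper named trace_stack; it only covers the main loop and leaves out the trailing removals and the leading-zero handling. Each digit is pushed once and popped at most once, so the main loop runs in \(O(|n|)\) time overall.

```python
def trace_stack(num, k):
    """Yield (digit, stack contents, remaining k) after each digit is processed."""
    stack = []
    for digit in num:
        # Pop larger digits from the top while removals are still available.
        while k > 0 and stack and stack[-1] > digit:
            stack.pop()
            k -= 1
        stack.append(digit)
        yield digit, "".join(stack), k


for digit, stack, remaining in trace_stack("1432219", 3):
    print(digit, stack, remaining)

# 1 1 3
# 4 14 3
# 3 13 2
# 2 12 1
# 2 122 1
# 1 121 0
# 9 1219 0
```

Once the remaining count reaches 0, every later digit is simply appended, which is why the final stack holds "1219".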


class Solution:
    def removeKdigits(self, num: str, k: int) -> str:
        return solve(num, k)