Leetcode: 1425. Constrained Subsequence Sum

Problem Statement: Given an array of integers \(a\) and an integer \(k\), return the maximum sum of a non-empty subsequence in which every two consecutive chosen elements are at most \(k\) positions apart in \(a\).

Let's use Dynamic Programming to find this value. Let \(f(i)\) be the maximum sum of a valid subsequence whose first element is \(a[i]\). The answer to the problem is then \(\max_{0 \leq i < |a|} f(i)\), but how can we compute \(f\) efficiently? The naive approach would be

\begin{equation*} f(i)=\begin{cases} a[i], & \mbox{if $i = |a| - 1$} \\ \max\left(a[i],\; a[i]+\max_{i < j \leq \min(i + k,\, |a| - 1)} f(j)\right), & \mbox{otherwise}. \end{cases} \end{equation*}
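
As a baseline, the recurrence can be transcribed directly. The sketch below is a minimal \(O(|a| \times k)\) version, useful mainly for cross-checking a faster solution on small inputs. The function name `naive_constrained_sum` is made up for this illustration, and it assumes a non-empty array and that consecutive picks may be up to \(k\) positions apart.

```python
from typing import List


def naive_constrained_sum(a: List[int], k: int) -> int:
    """Evaluate the recurrence for f directly, from right to left."""
    n = len(a)
    f = [0] * n
    for i in range(n - 1, -1, -1):
        # Best f(j) over the next k positions, clipped to the end of the array.
        # default=0 covers the last index, where no j exists.
        best_next = max((f[j] for j in range(i + 1, min(i + k, n - 1) + 1)), default=0)
        # max(a[i], a[i] + best_next) rewritten as a[i] + max(0, best_next).
        f[i] = a[i] + max(0, best_next)
    return max(f)
```

The inner `max` over up to \(k\) values is the part that the rest of the post replaces with a constant-time lookup.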

If we iterate over \(j\) for every \(i\), the time complexity of the algorithm is \(O(|a| \times k)\), which is too slow for this problem. We have to compute the inner maximum of \(f(j)\) over the window that follows \(i\) efficiently, for all \(i\). This can be done with a Sliding Window Maximum data structure, whose add operation runs in amortized \(O(1)\) time.
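
To illustrate the idea in isolation, here is a small sketch of a sliding window maximum over a plain list. It keeps a deque of indices whose values are in decreasing order, so the front is always the current maximum. The function name `window_maxima` is hypothetical. Each index is appended once and removed at most once, which is where the amortized \(O(1)\) cost per element comes from.

```python
from collections import deque
from typing import List


def window_maxima(values: List[int], size: int) -> List[int]:
    """Return, for each position, the maximum of the last `size` values seen so far."""
    candidates = deque()  # indices; their values are strictly decreasing
    result = []
    for i, value in enumerate(values):
        # Drop the front index once it falls outside the window ending at i.
        if candidates and candidates[0] <= i - size:
            candidates.popleft()
        # Older values that are smaller or equal can never be the maximum again.
        while candidates and values[candidates[-1]] <= value:
            candidates.pop()
        candidates.append(i)
        result.append(values[candidates[0]])
    return result


print(window_maxima([1, 3, 2, 0, 5], 2))  # [1, 3, 3, 2, 5]
```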

from collections import deque
from typing import List

class SlidingWindowMaximum:
    def __init__(self, size):
        self.size = size
        self.window = deque(maxlen=size)
        self.index = 0

    def add(self, value):
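        # Evict entries whose index has slid out of the window.
        while len(self.window) > 0 and self.window[0][1] <= self.index - self.size:
            self.window.popleft()

        # Drop older entries that are not larger than value: they can never be the maximum again.
        while len(self.window) > 0 and self.window[-1][0] <= value:
            self.window.pop()
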
        self.window.append((value, self.index))
        self.index += 1

        return self.max()

    def max(self):
        return self.window[0][0]

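def solve(nums, k):
    dp = [0] * len(nums)
    swm = SlidingWindowMaximum(k)

    # Base case: the last element can only start a subsequence of itself.
    # It must be added to the window so that swm.max() is defined below.
    dp[-1] = nums[-1]
    ans = swm.add(dp[-1])

    # Fill dp from right to left; before each step swm holds dp[i + 1 .. i + k].
    for i in range(len(nums) - 2, -1, -1):
        dp[i] = max(nums[i], nums[i] + swm.max())
        # The largest window maximum seen so far equals the largest dp value.
        ans = max(ans, swm.add(dp[i]))

    return ans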

assert solve([10, 2, -10, 5, 20], 2) == 37
assert solve([-1, -2, -3], 1) == -1
assert solve([10, -2, -10, -5, 20], 2) == 23

class Solution:
    def constrainedSubsetSum(self, nums: List[int], k: int) -> int:
        return solve(nums, k)