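"""LeetCode 1289: Minimum Falling Path Sum II.

Problem statement: given an n x n integer matrix ``grid``, return the minimum
sum of a falling path with non-zero shifts, i.e. a choice of exactly one
element from each row such that no two elements chosen from adjacent rows are
in the same column.
"""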

from typing import List


class Solution:
    def minFallingPathSum(self, grid: List[List[int]]) -> int:
        """Return the minimum sum of a falling path with non-zero shifts.

        A falling path with non-zero shifts picks exactly one element from
        each row of ``grid`` such that elements picked from adjacent rows are
        never in the same column.

        Args:
            grid: Rectangular matrix of integers with at least one row and
                one column (rows need not equal columns).

        Returns:
            The minimum achievable path sum.

        Raises:
            ValueError: If ``grid`` is empty, or if it has more than one row
                but only one column (no valid falling path exists).

        Runs in O(rows * cols) time and O(1) extra space.
        """
        if not grid or not grid[0]:
            raise ValueError("grid must contain at least one row and one column")
        if len(grid) > 1 and len(grid[0]) == 1:
            raise ValueError(
                "a single-column grid with multiple rows has no valid falling path"
            )

        # Of the previous row we only need its smallest path sum, the column
        # that sum ended in, and the second-smallest sum. Every column except
        # that arg-min continues from the smallest; the arg-min column itself
        # must continue from the second-smallest. The seed state (0, -1, 0)
        # makes the first row contribute just its own values.
        best, best_col, second = 0, -1, 0
        for row in grid:
            new_best = new_second = float("inf")
            new_best_col = -1
            for col, value in enumerate(row):
                total = value + (second if col == best_col else best)
                if total < new_best:
                    new_second = new_best
                    new_best, new_best_col = total, col
                elif total < new_second:
                    new_second = total
            best, best_col, second = new_best, new_best_col, new_second
        return best


# Smoke tests: run with `python <this file>`. Guarded so importing the module
# has no side effects.
if __name__ == "__main__":
    assert Solution().minFallingPathSum([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) == 13
    assert Solution().minFallingPathSum([[7]]) == 7