LeetCode 778: Swim in Rising Water

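Problem Statement: grid[i][j] is the elevation of cell (i, j); at time t the water depth everywhere is t, and you may swim between 4-directionally adjacent cells when both elevations are at most t. Return the least time t at which you can swim from (0, 0) to the bottom-right cell.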

from typing import List
from heapq import heappush, heappop


class Solution:
    def swimInWater(self, grid: List[List[int]]) -> int:
        """Return the least time at which (0, 0) can reach the bottom-right cell.

        Best-first search (Dijkstra on a max-of-path metric): each heap entry
        is keyed by the highest elevation along the path used to reach that
        cell, so the key with which the target is first popped is the answer.

        The grid may be rectangular; the target is the last cell of the last
        row.
        """
        rows, cols = len(grid), len(grid[0])

        # Must be a set containing the *coordinate pair*; set((v, 0, 0)) would
        # instead hold the loose values {v, 0} and never mark the start cell.
        seen = {(0, 0)}
        pq = [(grid[0][0], 0, 0)]

        while pq:
            t, i, j = heappop(pq)

            if i == rows - 1 and j == cols - 1:
                return t

            for ni, nj in ((i, j + 1), (i, j - 1), (i + 1, j), (i - 1, j)):
                if 0 <= ni < rows and 0 <= nj < cols and (ni, nj) not in seen:
                    # Marking on push is safe: entries are popped in
                    # nondecreasing t, so the first push already carries the
                    # smallest achievable max(t, grid[ni][nj]) for this cell.
                    seen.add((ni, nj))
                    heappush(pq, (max(grid[ni][nj], t), ni, nj))


# Examples from the problem statement.  They run at import time, before the
# second ``class Solution`` further down rebinds the name, so they exercise
# the heap-based implementation defined just above.
assert 3 == Solution().swimInWater([[0, 2], [1, 3]])
assert 16 == Solution().swimInWater(
    [
        [0, 1, 2, 3, 4],
        [24, 23, 22, 21, 5],
        [12, 13, 14, 15, 16],
        [11, 17, 18, 19, 20],
        [10, 9, 8, 7, 6],
    ]
)
class Solution:
    """Union-find solution to "Swim in Rising Water".

    NOTE: this class rebinds the module-level name ``Solution`` and therefore
    replaces the heap-based definition earlier in the file for any code that
    runs after this point.
    """

    def swimInWater(self, grid: List[List[int]]) -> int:
        """Return the least time t at which (0, 0) and the bottom-right cell connect.

        Cells are opened in increasing order of elevation.  Each opened cell is
        merged with every 4-neighbour whose elevation is <= its own, and the
        elevation at which the two corners first share a component is the
        answer.  Works for rectangular grids and tolerates equal elevations.
        """
        N = len(grid)
        M = len(grid[0])
        D = ((+1, +0), (-1, +0), (+0, +1), (+0, -1))

        parent = {}
        size = {}

        def find(u):
            # Iterative find with path halving: avoids recursion entirely, so
            # it cannot run into Python's recursion limit on long chains.
            parent.setdefault(u, u)
            while parent[u] != u:
                parent[u] = parent[parent[u]]
                u = parent[u]
            return u

        def union(u, v):
            ru, rv = find(u), find(v)
            if ru == rv:
                return
            # Union by size keeps every tree O(log n) deep.
            if size.get(ru, 1) < size.get(rv, 1):
                ru, rv = rv, ru
            parent[rv] = ru
            size[ru] = size.get(ru, 1) + size.get(rv, 1)

        start, target = (0, 0), (N - 1, M - 1)
        cells = sorted((grid[r][c], r, c) for r in range(N) for c in range(M))
        for t, i, j in cells:
            for di, dj in D:
                ni, nj = i + di, j + dj
                # A neighbour with elevation <= t is already submerged at time t.
                if 0 <= ni < N and 0 <= nj < M and grid[ni][nj] <= t:
                    union((i, j), (ni, nj))
            if find(start) == find(target):
                return t