LeetCode 1632: Rank Transform of a Matrix

Problem Statement

from typing import List
from collections import defaultdict


class UnionFind:
    """Disjoint-set forest over arbitrary hashable elements.

    Elements are registered lazily by ``union``. ``p`` maps every known
    element to its parent; a root is its own parent.
    """

    def __init__(self):
        # element -> parent; a root maps to itself.
        self.p = {}

    def find(self, u):
        """Return the representative (root) of the set containing ``u``.

        The parent chain is walked iteratively, so an arbitrarily long
        chain cannot exhaust the interpreter's recursion limit. Every
        node visited on the way is then re-pointed at the root (path
        compression).

        Raises:
            KeyError: if ``u`` has never been passed to ``union``.
        """
        # Pass 1: locate the root.
        root = u
        while self.p[root] != root:
            root = self.p[root]

        # Pass 2: compress the path from ``u`` up to the root.
        while u != root:
            parent = self.p[u]
            self.p[u] = root
            u = parent
        return root

    def union(self, u, v):
        """Merge the sets containing ``u`` and ``v``.

        Either element is registered as a singleton set first if it has
        not been seen before.
        """
        self.p.setdefault(u, u)
        self.p.setdefault(v, v)
        self.p[self.find(u)] = self.find(v)

class Solution:
    """Solver for the "rank transform of a matrix" problem."""

    def matrixRankTransform(self, matrix: List[List[int]]) -> List[List[int]]:
        """Return a matrix of ranks with the same shape as ``matrix``.

        Cells holding the same value that share a row or a column
        (directly or transitively) are grouped and receive one common
        rank. Groups are processed in increasing value order, and each
        group's rank is the smallest rank still available in every row
        and column it touches.

        Args:
            matrix: Rectangular matrix of mutually comparable values
                (all rows are assumed to have the same length).

        Returns:
            A new matrix of positive integer ranks. An empty matrix, or
            one whose rows are empty, yields an equally shaped empty
            result.
        """
        # Nothing to rank; also avoids indexing ``matrix[0]`` on ``[]``.
        if not matrix or not matrix[0]:
            return [[] for _ in matrix]

        n_rows, n_cols = len(matrix), len(matrix[0])

        # One disjoint-set per distinct value. Row ``i`` is node ``i`` and
        # column ``j`` is node ``~j`` (always negative), so row and column
        # ids never collide inside the same structure.
        uf = defaultdict(UnionFind)
        for i, row in enumerate(matrix):
            for j, value in enumerate(row):
                uf[value].union(i, ~j)

        # value -> component root -> cells belonging to that component.
        components = defaultdict(lambda: defaultdict(list))
        for i, row in enumerate(matrix):
            for j, value in enumerate(row):
                components[value][uf[value].find(i)].append((i, j))

        # Smallest rank still available in each row / column.
        next_in_row = [1] * n_rows
        next_in_col = [1] * n_cols
        ans = [[0] * n_cols for _ in range(n_rows)]

        for value in sorted(components):
            for cells in components[value].values():
                # Every cell of the component must share one rank, so take
                # the largest lower bound over all rows/columns it touches.
                rank = max(max(next_in_row[i], next_in_col[j]) for i, j in cells)
                for i, j in cells:
                    next_in_row[i] = rank + 1
                    next_in_col[j] = rank + 1
                    ans[i][j] = rank
        return ans


# Import-time self-checks: each entry pairs an input matrix with the rank
# matrix that matrixRankTransform must produce for it.
_RANK_CASES = (
    ([[1, 2], [3, 4]], [[1, 2], [2, 3]]),
    ([[7, 7], [7, 7]], [[1, 1], [1, 1]]),
    (
        [[20, -21, 14], [-19, 4, 19], [22, -47, 24], [-19, 4, 19]],
        [[4, 2, 3], [1, 3, 4], [5, 1, 6], [1, 3, 4]],
    ),
)
for _given, _expected in _RANK_CASES:
    assert Solution().matrixRankTransform(_given) == _expected