LeetCode 2421: Number of Good Paths

Problem Statement

class Solution:
    def numberOfGoodPaths(self, vals: List[int], edges: List[List[int]]) -> int:
        """Count the good paths in a tree.

        A good path starts and ends on nodes carrying the same value, and no
        node along it carries a larger value. Every single node counts as a
        good path on its own.

        The edges are processed in ascending order of their larger endpoint
        value while components are merged with a disjoint-set forest. When an
        edge whose larger endpoint value is ``peak`` joins two components,
        every node in either component is <= ``peak``, so each pair of
        ``peak``-valued nodes split across the two sides forms exactly one
        new good path.

        Args:
            vals: ``vals[i]`` is the value of node ``i``.
            edges: Undirected edges ``[a, b]``. Assumed to form a tree over
                the nodes ``0..len(vals) - 1``. The list is not modified.

        Returns:
            The number of distinct good paths, single-node paths included.
        """
        n = len(vals)
        parent = list(range(n))
        size = [1] * n
        # Per root: the largest value inside the component and how many of
        # its nodes carry that value. Smaller values can never end a future
        # good path through this component, so they need not be tracked.
        comp_max = list(vals)
        comp_max_count = [1] * n

        def find(node: int) -> int:
            """Return the root of ``node``'s component, compressing the path."""
            # Iterative on purpose: a long chain of equal-valued nodes would
            # overflow the interpreter's recursion limit with a recursive find.
            root = node
            while parent[root] != root:
                root = parent[root]
            while parent[node] != root:
                next_node = parent[node]
                parent[node] = root
                node = next_node
            return root

        def larger_endpoint(edge: List[int]) -> int:
            return max(vals[edge[0]], vals[edge[1]])

        # Each node by itself is a good path.
        total = n
        # sorted() works on a copy so the caller's edge list keeps its order.
        for a, b in sorted(edges, key=larger_endpoint):
            root_a = find(a)
            root_b = find(b)
            if root_a == root_b:
                # Cannot happen in a tree; skip rather than double count.
                continue

            peak = max(vals[a], vals[b])
            count_a = comp_max_count[root_a] if comp_max[root_a] == peak else 0
            count_b = comp_max_count[root_b] if comp_max[root_b] == peak else 0
            total += count_a * count_b

            # Union by size keeps the forest shallow.
            if size[root_a] < size[root_b]:
                root_a, root_b = root_b, root_a
            parent[root_b] = root_a
            size[root_a] += size[root_b]
            comp_max[root_a] = peak
            comp_max_count[root_a] = count_a + count_b

        return total