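"""LeetCode 1579. Remove Max Number of Edges to Keep Graph Fully Traversable.

Problem statement: given ``n`` nodes labelled ``1..n`` and typed edges
``[type, u, v]`` (type 1 is traversable only by Alice, type 2 only by Bob,
type 3 by both), return the maximum number of edges that can be removed so
that both Alice and Bob can still reach every node, or -1 if that is
impossible.
"""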

from typing import List


class Solution:
    """Solver for LeetCode 1579 (Remove Max Number of Edges to Keep Graph Fully Traversable)."""

    def maxNumEdgesToRemove(self, n: int, edges: List[List[int]]) -> int:
        """Return how many edges can be removed while both players can still reach every node.

        A single union-find structure holds both players' graphs: key ``u``
        stands for node ``u`` in Alice's graph and key ``~u`` (``-u - 1``)
        stands for it in Bob's, so the two key spaces never collide for
        non-negative node labels.

        Args:
            n: Number of nodes, labelled ``1..n``.
            edges: Triples ``[t, u, v]``. ``t == 3`` edges are usable by both
                players, ``t == 2`` only by Bob; any other type is treated as
                Alice-only (type 1).

        Returns:
            ``len(edges)`` minus the number of edges kept to connect both
            graphs, or ``-1`` if either graph cannot span nodes ``1..n``.
        """
        parent = {}
        size = {}

        def find(x):
            # Iterative on purpose: a recursive find overflows Python's default
            # recursion limit when an unbalanced chain of ~1000+ nodes forms.
            parent.setdefault(x, x)
            root = x
            while parent[root] != root:
                root = parent[root]
            # Path compression: re-point every node on the walked path at root.
            while parent[x] != root:
                parent[x], x = root, parent[x]
            return root

        def union(a, b):
            """Merge the sets containing a and b; return True if they were separate."""
            ra, rb = find(a), find(b)
            if ra == rb:
                return False
            # Union by size: hang the smaller tree under the larger to keep depth low.
            if size.get(ra, 1) < size.get(rb, 1):
                ra, rb = rb, ra
            parent[rb] = ra
            size[ra] = size.get(ra, 1) + size.get(rb, 1)
            return True

        kept = 0
        # Shared (type 3) edges first: one of them can stand in for a
        # player-specific edge in both graphs, so preferring them maximises removals.
        for t, u, v in edges:
            if t == 3 and union(u, v):
                # While only type-3 edges have been added, both graphs are
                # identical, so Bob's union is guaranteed to merge as well.
                union(~u, ~v)
                kept += 1

        # Player-specific edges: each graph is independent now, so order is irrelevant.
        for t, u, v in edges:
            if t == 3:
                continue
            merged = union(~u, ~v) if t == 2 else union(u, v)
            if merged:
                kept += 1

        nodes = range(1, n + 1)
        alice_connected = len({find(u) for u in nodes}) == 1
        bob_connected = len({find(~u) for u in nodes}) == 1
        return len(edges) - kept if alice_connected and bob_connected else -1


def _check_examples():
    """Sanity-check the solver against the three sample inputs."""
    cases = (
        (4, [[3, 1, 2], [3, 2, 3], [1, 1, 3], [1, 2, 4], [1, 1, 2], [2, 3, 4]], 2),
        (4, [[3, 1, 2], [3, 2, 3], [1, 1, 4], [2, 1, 4]], 0),
        (4, [[3, 2, 3], [1, 1, 2], [2, 3, 4]], -1),
    )
    for node_count, edge_list, expected in cases:
        assert Solution().maxNumEdgesToRemove(node_count, edge_list) == expected


_check_examples()