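# LeetCode 924. Minimize Malware Spread
#
# Problem statement: given an n x n adjacency matrix `graph` and a list
# `initial` of initially infected nodes, malware spreads along edges until
# every node connected to an infected node is infected. Remove exactly one
# node from `initial` and return the node whose removal minimizes the final
# number of infected nodes; on ties, return the node with the smallest index.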

from typing import List


class Solution:
    """LeetCode 924: label infected components with an explicit-stack DFS."""

    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        """Return the node in ``initial`` whose removal saves the most nodes.

        Args:
            graph: square adjacency matrix; ``graph[a][b] == 1`` marks an edge.
            initial: indices of the initially infected nodes.

        Returns:
            The node whose removal minimises the final infection count
            (smallest index among equally good choices), or ``0`` when
            ``initial`` is empty.
        """
        n = len(graph)
        owner = [None] * n  # owner[x]: the seed whose search first reached x
        size = [0] * n      # size[seed]: node count of the component seed owns
        seeds = set(initial)
        if not seeds:
            return 0

        for seed in seeds:
            if owner[seed] is not None:
                continue  # already absorbed by an earlier seed's component
            owner[seed] = seed
            size[seed] += 1
            stack = [seed]
            while stack:
                node = stack.pop()
                for nxt in range(n):
                    if graph[node][nxt] == 1 and owner[nxt] is None:
                        owner[nxt] = seed
                        size[seed] += 1
                        stack.append(nxt)

        # Number of seeds sharing each component, indexed by the owning seed.
        seeds_in = [0] * n
        for seed in seeds:
            seeds_in[owner[seed]] += 1

        def saved(node):
            # Removing a seed only helps when it is its component's sole seed.
            root = owner[node]
            return size[root] if seeds_in[root] <= 1 else 0

        # Largest saving wins; ties go to the smallest node index.
        return min(seeds, key=lambda node: (-saved(node), node))


# Import-time smoke tests for the Solution class defined above:
# (graph, initial, expected answer).
_CASES = (
    ([[1, 1, 0], [1, 1, 0], [0, 0, 1]], [0, 1], 0),
    ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0, 2], 0),
    ([[1, 1, 1], [1, 1, 1], [1, 1, 1]], [1, 2], 1),
)
for _graph, _initial, _expected in _CASES:
    assert Solution().minMalwareSpread(_graph, _initial) == _expected
from typing import List


class Solution:
    """LeetCode 924 solved with a union-find over the adjacency matrix."""

    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        """Return the node in ``initial`` whose removal saves the most nodes.

        A node's removal only changes the outcome when it is the sole
        infected node of its connected component. In that case the whole
        component is saved. Among such nodes the one with the largest
        component wins, and ties go to the smallest index. If no node is
        alone in its component, the smallest index in ``initial`` is returned.

        Args:
            graph: square adjacency matrix; a truthy ``graph[u][v]`` is an edge.
            initial: indices of the initially infected nodes. The list is not
                modified.

        Returns:
            The chosen node index.

        Raises:
            IndexError: if ``initial`` is empty.
        """
        n = len(graph)
        parent = list(range(n))

        def find(u: int) -> int:
            # Iterative path halving: the unions below can build an n-long
            # chain (e.g. on a path graph), which a recursive find would walk
            # one stack frame per link.
            while parent[u] != u:
                parent[u] = parent[parent[u]]
                u = parent[u]
            return u

        def union(u: int, v: int) -> None:
            parent[find(u)] = find(v)

        for u in range(n):
            for v in range(u + 1, n):
                if graph[u][v]:
                    union(u, v)

        # size[root]: number of nodes in the component rooted at root.
        size = [0] * n
        for u in range(n):
            size[find(u)] += 1

        # Sorted, de-duplicated copy: ascending iteration gives the
        # smallest-index tie-break without mutating the caller's list.
        candidates = sorted(set(initial))

        # infected[root]: how many distinct initial nodes lie in that component.
        infected = [0] * n
        for u in candidates:
            infected[find(u)] += 1

        best, best_size = None, -1
        for u in candidates:
            root = find(u)
            # Strict '>' keeps the earlier (smaller) node on equal sizes.
            if infected[root] == 1 and size[root] > best_size:
                best, best_size = u, size[root]

        return best if best is not None else candidates[0]