Leetcode: 749. Contain Virus

Problem Statement

The difficulty of this problem lies in the amount of code you have to write rather than in any insight you have to come up with. Find each cluster of virus (its infected cells, plus the cells it will infect in the next turn) using a depth-first search. Sort the clusters and pick the one that would infect the most cells. Each step of this process takes \(O(n \times m)\). So, the algorithm is \(O(n \times m \times k)\), where \(k\) is the maximum number of rounds: \(k \leq (n \times m) / 9\) (maximum number of clusters). Therefore, the time complexity is \(O(n^2 \times m^2)\).

from typing import List


class Solution:
    """Solver for LeetCode 749, "Contain Virus"."""

    def containVirus(self, infected: List[List[int]]) -> int:
        """Return the number of walls built while quarantining the virus.

        Each round, the active (not yet sealed) connected infected region
        that threatens the most uninfected cells is sealed off, costing one
        wall per shared edge between an infected cell of that region and an
        uninfected neighbour. Every other active region then spreads to the
        cells it threatens. The simulation stops once there is no active
        region left or no active region threatens any cell.

        Args:
            infected: Rectangular grid in which truthy cells are infected.
                The grid is copied internally and is not modified.

        Returns:
            Total number of walls installed. An empty grid yields 0.
        """
        if not infected or not infected[0]:
            return 0

        n = len(infected)
        m = len(infected[0])
        # Work on a private copy so the caller's grid is left untouched.
        grid = [list(row) for row in infected]
        # walled[i][j] becomes True once the region holding (i, j) is sealed.
        walled = [[False] * m for _ in range(n)]

        def neighbours(i, j):
            """Yield the in-bounds 4-directional neighbours of (i, j)."""
            for di, dj in ((0, 1), (0, -1), (1, 0), (-1, 0)):
                ni, nj = i + di, j + dj
                if 0 <= ni < n and 0 <= nj < m:
                    yield ni, nj

        def find_cluster(start, seen):
            """Collect the active region containing ``start``.

            Uses an explicit stack rather than recursion so that a large
            region cannot exceed the interpreter's recursion limit.

            Returns:
                ``(cells, frontier)``: the infected cells of the region and
                the uninfected cells adjacent to it.
            """
            cells = set()
            frontier = set()
            seen[start[0]][start[1]] = True
            stack = [start]
            while stack:
                i, j = stack.pop()
                cells.add((i, j))
                for ni, nj in neighbours(i, j):
                    if seen[ni][nj] or walled[ni][nj]:
                        # Already collected, or cut off behind a wall.
                        continue
                    if grid[ni][nj]:
                        seen[ni][nj] = True
                        stack.append((ni, nj))
                    else:
                        frontier.add((ni, nj))
            return cells, frontier

        ans = 0
        while True:
            seen = [[False] * m for _ in range(n)]
            clusters = []
            for i in range(n):
                for j in range(m):
                    if grid[i][j] and not walled[i][j] and not seen[i][j]:
                        clusters.append(find_cluster((i, j), seen))

            if not clusters:
                break

            # Stable sort: among equally threatening regions the one found
            # last in row-major order ends up at the back and is chosen.
            clusters.sort(key=lambda cluster: len(cluster[1]))
            cells, frontier = clusters[-1]
            if not frontier:
                # No active region can spread any further, so no more walls
                # are needed and nothing else will change.
                break

            # One wall per (threatened cell, infected cell) shared edge.
            ans += sum(
                1
                for i, j in frontier
                for neighbour in neighbours(i, j)
                if neighbour in cells
            )
            for i, j in cells:
                walled[i][j] = True

            # Every region that was not sealed this round spreads.
            for _, spreading in clusters[:-1]:
                for i, j in spreading:
                    grid[i][j] = 1

        return ans


# Self-checks executed at import time: each entry pairs an input grid with
# the number of walls the solver is expected to report for it.
for _grid, _expected_walls in (
    (
        [
            [0, 1, 0, 0, 0, 0, 0, 1],
            [0, 1, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 0, 0],
        ],
        10,
    ),
    (
        [
            [1, 1, 1],
            [1, 0, 1],
            [1, 1, 1],
        ],
        4,
    ),
    (
        [
            [1, 1, 1, 0, 0, 0, 0, 0, 0],
            [1, 0, 1, 0, 1, 1, 1, 1, 1],
            [1, 1, 1, 0, 0, 0, 0, 0, 0],
        ],
        13,
    ),
    (
        [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ],
        56,
    ),
):
    # A fresh solver per case, exactly one call per grid.
    assert Solution().containVirus(_grid) == _expected_walls