LeetCode 924: Minimize Malware Spread

link

Simulate with BFS

For each node in initial, one BFS simulates how far the malware would spread if that node were removed from the initially infected set. We want the node whose removal results in the minimum spread. If there is a tie, we break it in favor of the smaller node id.

Say len(initial) = m. Time: \mathcal{O}(m \cdot n^2), space: \mathcal{O}(n^2).

class Solution:
    """Brute-force solution: simulate the spread once per candidate node."""

    def spread_without(self, removed_node, graph, initial):
        """Return how many nodes end up infected if `removed_node` is not a source.

        `graph` is an adjacency matrix (truthy entry means an edge) and
        `initial` holds the initially infected nodes. `removed_node` is only
        dropped from the set of sources; it stays in the graph and may still
        be infected through its neighbours.
        """
        infected = {s for s in initial if s != removed_node}
        q = deque(infected)
        while q:
            u = q.popleft()
            for v, is_connected in enumerate(graph[u]):
                if not is_connected or v in infected:
                    continue
                # Mark on enqueue so every node enters the queue at most once.
                infected.add(v)
                q.append(v)

        return len(infected)

    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        """Return the node of `initial` whose removal minimises the final spread.

        Ties are broken in favour of the smaller node id, which the
        (spread, node) key handles directly. A single `min` pass replaces
        building and sorting a full list of candidates.

        Raises:
            ValueError: if `initial` is empty.
        """
        return min(
            initial,
            key=lambda u: (self.spread_without(u, graph, initial), u),
        )

Union-Find

If two or more nodes from initial appear in the same connected component, removing any one of them does not help prevent spread, because the remaining infected nodes still infect the whole component. On the other hand, if a connected component has a single initial node, removing that node will prevent spread for the entire connected component. Using Union-Find, we keep track of the connected components. From the connected components, we can then find P(u) for each u in initial, where removing u would have prevented spread to P(u) nodes. We want the node with the largest P(u), breaking ties in favor of the smaller node id.

Time: \mathcal{O}(n^2 \cdot \lg^*{n^2}), space: \mathcal{O}(n^2).

class Spread:
    """Candidate node together with the number of infections its removal prevents.

    Instances order themselves so that the best candidate sorts first:
    larger `prevention` wins, and equal prevention falls back to the
    smaller `vertex`.
    """

    def __init__(self, u, infected_cc_size, infected_cc_source_count):
        self.vertex = u
        # A component with more than one source stays infected regardless of
        # which single source is removed, so nothing is prevented.
        self.prevention = 0 if infected_cc_source_count > 1 else infected_cc_size

    def __lt__(self, other_spread):
        if self.prevention == other_spread.prevention:
            return self.vertex < other_spread.vertex

        # Higher prevention must compare as "smaller" to sort first.
        return self.prevention > other_spread.prevention


class UnionFind:
    """Disjoint-set forest over the nodes 0..n-1 with union by rank and path compression."""

    def __init__(self, n):
        # Every node starts as the root of its own singleton set with rank 0.
        self.parent_of = {node: node for node in range(n)}
        self.rank_of = {node: 0 for node in range(n)}

    def __compress_path(self, u, root):
        """Re-point every node on the path from `u` up to its root directly at `root`."""
        parents = self.parent_of
        while True:
            above = parents[u]
            if above == u:
                # Reached a node that is its own parent: the walk is done.
                break
            parents[u] = root
            u = above

    def __find(self, u):
        """Return the root of the set containing `u`, compressing the path walked."""
        parents = self.parent_of
        root = u
        while parents[root] != root:
            root = parents[root]
        self.__compress_path(u, root)
        return root

    def union(self, u, v):
        """Merge the sets containing `u` and `v`; a no-op if they already coincide."""
        root_u = self.__find(u)
        root_v = self.__find(v)
        if root_u == root_v:
            return

        rank_u = self.rank_of[root_u]
        rank_v = self.rank_of[root_v]
        if rank_u > rank_v:
            # The shallower tree (v's) hangs under the deeper one.
            self.parent_of[root_v] = root_u
            return

        self.parent_of[root_u] = root_v
        if rank_u == rank_v:
            # Equal ranks: the surviving root grows by one level.
            self.rank_of[root_v] += 1

    def infected_ccs(self, initial):
        """Describe the connected component of each initially infected node.

        Returns a dict mapping every node `u` in `initial` to a tuple
        `(component size, number of entries of initial in that component)`.
        """
        # Number of nodes per component, keyed by the component's root.
        cc_size = {}
        for node in self.parent_of:
            root = self.__find(node)
            cc_size[root] = cc_size.get(root, 0) + 1

        # Resolve each source's root once, then tally sources per component.
        source_roots = [(source, self.__find(source)) for source in initial]
        cc_source_count = {}
        for _, root in source_roots:
            cc_source_count[root] = cc_source_count.get(root, 0) + 1

        return {
            source: (cc_size[root], cc_source_count[root])
            for source, root in source_roots
        }

class Solution:
    """Union-Find solution: rank each infected node by the infections its removal prevents."""

    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        """Return the node of `initial` to remove, preferring the smaller id on ties.

        `graph` is an adjacency matrix; every truthy entry joins its row and
        column nodes into the same connected component.
        """
        components = UnionFind(len(graph))
        for u, row in enumerate(graph):
            for v, is_connected in enumerate(row):
                if not is_connected:
                    continue
                components.union(u, v)

        ccs = components.infected_ccs(initial)
        # ccs[u] is (component size, source count), matching Spread's arguments.
        candidates = [Spread(u, *ccs[u]) for u in initial]
        # Spread orders the best candidate first.
        candidates.sort()

        return candidates[0].vertex

Leave a comment