LeetCode 2602. Minimum Operations to Make All Array Elements Equal


Brute force

For a query q, each operation changes one element by 1, so the number of ops is \sum_{i} |x_i - q|.

Time: \mathcal{O}(n \cdot m), where n = len(nums) and m = len(queries); space: \mathcal{O}(1), not counting the output.

from typing import List


class Solution:
    def minOperations(self, nums: List[int], queries: List[int]) -> List[int]:
        num_ops = [0] * len(queries)
        for i, q in enumerate(queries):
            # Add up the distance from every element to q.
            num_ops[i] = sum(abs(q - x) for x in nums)
        return num_ops

Sort and search

If q \ge x_i for every i, we can drop the | \cdot |. The number of ops is then \sum_i (q - x_i) = n \cdot q - \sum_i x_i. Once \sum_i x_i is known, each q takes \mathcal{O}(1) time. Similarly, if q \le x_i for every i, the number of ops is \sum_i x_i - n \cdot q, which also takes \mathcal{O}(1) time.
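
As a quick sanity check of the two closed forms, the sketch below compares them with the direct sum on a small array. The values and the helper name `direct_ops` are illustrative choices, not part of the solution.

```python
def direct_ops(nums, q):
    """Number of operations computed directly as the sum of |x - q|."""
    return sum(abs(x - q) for x in nums)


nums = [3, 1, 6, 8]              # illustrative values
n, total = len(nums), sum(nums)

q_high = 10                      # q is at least as large as every element
ops_high = n * q_high - total    # n*q - sum(x)

q_low = 0                        # q is at most as small as every element
ops_low = total - n * q_low      # sum(x) - n*q

print(ops_high, direct_ops(nums, q_high))   # prints 22 22
print(ops_low, direct_ops(nums, q_low))     # prints 18 18
```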

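In general, nums contains elements on both sides of q, so both cases occur at once. Sorting nums separates them: the elements smaller than q form a prefix and the elements larger than q form a suffix.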
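
Written out, with nums sorted as x_1 \le \dots \le x_n and k denoting the number of elements smaller than q (k is notation introduced here for illustration), the sum splits into the two cases above:

```latex
\sum_{i=1}^{n} |x_i - q|
  = \sum_{i=1}^{k} (q - x_i) + \sum_{i=k+1}^{n} (x_i - q)
  = \Big( k \cdot q - \sum_{i=1}^{k} x_i \Big)
    + \Big( \sum_{i=k+1}^{n} x_i - (n - k) \cdot q \Big)
```

Elements equal to q contribute zero, so it does not matter which side they are counted on.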

We binary-search the sorted nums for the position where q would be inserted. Elements before that position satisfy x_i \le q, and elements from that position on satisfy x_i \ge q. Applying the formulas above to each region needs the sum of a prefix and the sum of a suffix, so we precompute cumulative sums.
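
A minimal sketch of this step, using the standard library rather than hand-written helpers: `bisect.bisect_left` finds the split point and `itertools.accumulate` builds the cumulative sums. The function name `ops_for_query` and the leading 0 in the prefix array are choices made for this sketch only.

```python
from bisect import bisect_left
from itertools import accumulate


def ops_for_query(sorted_nums, prefix, q):
    """Operations needed to make every element equal to q.

    sorted_nums is ascending; prefix[i] is the sum of its first i elements.
    """
    n = len(sorted_nums)
    k = bisect_left(sorted_nums, q)                 # elements strictly less than q
    below = k * q - prefix[k]                       # increments for the first k elements
    above = (prefix[n] - prefix[k]) - (n - k) * q   # decrements for the rest
    return below + above


nums = sorted([3, 1, 6, 8])
prefix = [0] + list(accumulate(nums))               # [0, 1, 4, 10, 18]
print([ops_for_query(nums, prefix, q) for q in (1, 5)])   # prints [14, 10]
```

The leading 0 means an empty prefix needs no special case, at the cost of one extra array slot.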

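Time: \mathcal{O}( \max{(m, n)} \cdot \lg{n} ), since sorting takes \mathcal{O}(n \lg{n}) and each of the m queries takes one \mathcal{O}(\lg{n}) binary search. Space: \mathcal{O}( \max{(m, n)} ) for the cumulative sums and the output.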

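from typing import List


class Solution:
    def find_insert_index(self, x, nums) -> int:
        # Binary search in sorted nums. Returns an index p such that
        # nums[:p] <= x <= nums[p:]. If x occurs in nums, p may be the
        # index of any of its occurrences.
        lo, hi = 0, len(nums)-1
        while lo <= hi:
            mid = (lo+hi) // 2
            if nums[mid] == x:
                return mid
            if nums[mid] < x:
                lo = mid+1
            else:
                hi = mid-1
        return lo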

    def find_cumulative_sum(self, nums) -> List[int]:
        cumsum = [nums[0]] + [0] * (len(nums)-1)
        for i, x in enumerate(nums[1:], start=1):
            cumsum[i] = cumsum[i-1] + x
        return cumsum

    def count_increments(self, q, q_pos, cumsum) -> int:
        num_less = q_pos
        if num_less == 0:
            # For all x, q <= x.
            # No increment is necessary
            return 0

        # Sum of the first q_pos elements, all of which are <= q.
        sum_x = cumsum[q_pos-1]
        return num_less*q - sum_x

    def count_decrements(self, q, q_pos, cumsum) -> int:
        n = len(cumsum)
        num_more = n - q_pos
        if num_more == 0:
            # For all x, q >= x
            # No decrement is necessary
            return 0

        # Sum of the elements from q_pos onwards, all of which are >= q.
        sum_x = cumsum[n-1] - (cumsum[q_pos-1] if q_pos > 0 else 0)
        return sum_x - num_more*q
        
    def minOperations(self, nums: List[int], queries: List[int]) -> List[int]:
        m = len(queries)
        nums.sort()
        cumsum = self.find_cumulative_sum( nums )

        num_ops = [0]*m
        for i, q in enumerate(queries):
            q_pos = self.find_insert_index(q, nums)

            n_inc = self.count_increments(q, q_pos, cumsum)
            n_dec = self.count_decrements(q, q_pos, cumsum)

            num_ops[i] = n_inc + n_dec

        return num_ops
