LeetCode 1802: Maximum Value at a Given Index in a Bounded Array


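To maximize nums[index], every other element should be as small as possible. Moving away from index, each element can be at most 1 smaller than its neighbour closer to index, and no element may drop below 1. So for a candidate value val = nums[index], the smallest achievable total is bounded_sum(val), the sum of max(1, val - |i - index|) over all positions i. Every element is at least 1, so this total is at least n, and the problem guarantees maxSum >= n, so val = 1 is always feasible. bounded_sum is non-decreasing in val, so as val increases the predicate bounded_sum(val) > maxSum evaluates like: False, False, False, True, True, … In other words, once the predicate becomes True, it stays True. We want the value of val at the rightmost False evaluation, so we binary search for nums[index] in [1, maxSum].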
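To make the monotonicity concrete, here is a small sketch that builds the smallest valid array for a candidate peak value and lists the predicate for each candidate. The helper names `min_array` and `predicate_table` are hypothetical, introduced only for illustration; the inputs are LeetCode's first sample (n = 4, index = 2, maxSum = 6).

```python
from typing import List


def min_array(n: int, index: int, val: int) -> List[int]:
    # Smallest array with nums[index] == val: step down by 1 per position
    # away from index, never going below 1.
    return [max(1, val - abs(i - index)) for i in range(n)]


def predicate_table(n: int, index: int, max_sum: int, upto: int) -> List[bool]:
    # predicate(val) is True when even the smallest array exceeds max_sum.
    return [sum(min_array(n, index, v)) > max_sum for v in range(1, upto + 1)]


if __name__ == "__main__":
    print(min_array(4, 2, 3))           # [1, 2, 3, 2]
    print(predicate_table(4, 2, 6, 5))  # [False, False, True, True, True]
```

The rightmost False sits at val = 2, which is the expected answer for that sample.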

Time: \mathcal{O}(\lg{(\text{maxSum})} \cdot n), space: \mathcal{O}(1).

class Solution:
    def maxValue(self, n: int, index: int, maxSum: int) -> int:
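        def bounded_sum(val):
            # Minimum possible sum of nums when nums[index] == val:
            # values step down by 1 away from index but never go below 1.
            left_sum = 0
            left = index-1
            decr = 1
            while left >= 0:
                left_sum += max(1, val - decr)
                left -= 1
                decr += 1

            right_sum = 0
            right = index+1
            decr = 1
            while right < n:
                right_sum += max(1, val - decr)
                right += 1
                decr += 1

            return left_sum + val + right_sum

        # Find the rightmost val with bounded_sum(val) <= maxSum;
        # the loop exits with lo one past that value.
        lo, hi = 1, maxSum
        while lo <= hi:
            mid = (lo + hi) // 2
            if bounded_sum(mid) > maxSum:
                hi = mid-1
            else:
                lo = mid+1

        return lo-1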

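Raising nums[index] by 1 raises by 1 every element that is still above 1, so on each side of index the values form a decreasing run val - 1, val - 2, … that bottoms out at 1, followed by a tail of 1s. Each side is therefore an arithmetic series plus a constant run, and an arithmetic series has a closed-form sum, so bounded_sum can be computed in constant time.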
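Spelling out the series sum: write v for nums[index] and L for the number of positions on one side of index. The j-th neighbour on that side holds v - j while that is greater than 1, which holds for j <= v - 2, so k = min(L, v - 2) elements lie above 1 and the remaining L - k are 1. With m = v - 1 (and assuming v >= 2):

```latex
\text{side}(v, L) = \sum_{j=1}^{k} (v - j) + (L - k)
                  = \frac{m(m+1)}{2} - \frac{(m-k)(m-k+1)}{2} + (L - k),
\qquad k = \min(L,\, v - 2),\quad m = v - 1

\text{bounded\_sum}(v) = \text{side}(v,\, \text{index}) + v + \text{side}(v,\, n - 1 - \text{index})
```

This mirrors side_sum (with incremented_side_len playing the role of k) and bounded_sum in the code that follows, which special-cases max_val <= 2 where every other element is 1.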

Time: \mathcal{O}(\lg{(\text{maxSum})}), space: \mathcal{O}(1).

class Solution:
    def maxValue(self, n: int, index: int, maxSum: int) -> int:
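        def side_sum(max_val, begin, end):
            # Sum of one side (positions begin..end) of the minimal array,
            # where the position next to index holds max_val - 1.
            side_len = end - begin + 1
            # Neighbours at distance 1..max_val-2 stay above 1.
            incremented_side_len = min(side_len, max_val - 2)
            m = max_val - 1
            # m + (m - 1) + ... + (m - incremented_side_len + 1)
            incremented_sum = (
                m * (m + 1) // 2
                - (m - incremented_side_len) * (m - incremented_side_len + 1) // 2
            )

            # The remaining positions on this side are all 1.
            return incremented_sum + (side_len - incremented_side_len)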

        def bounded_sum(max_val):
            # max_val == 1: all ones (n); max_val == 2: one 2 and n - 1 ones.
            if max_val <= 2:
                return n + (max_val - 1)

            return (
                side_sum(max_val, 0, index - 1)
                + (max_val)
                + side_sum(max_val, index + 1, n - 1)
            )

        #  0 1 2 3 4 5 6 7
        # [1 1 2 3 4 5 4 3]
        #  | | | | | | | |
        #  0 0 1 2 3 4 3 2
        #  | | | | | | | |
        # [1 1 1 1 1 1 1 1]

        # [    <=    |     >    ]
        #            |lo
        lo, hi = 1, maxSum
        while lo <= hi:
            mid = (lo + hi) // 2
            if bounded_sum(mid) > maxSum:
                hi = mid - 1
            else:
                lo = mid + 1

        return lo - 1
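One way to gain confidence in the closed form is to compare it against the O(n) loop from the first solution over all small inputs. The sketch below restates both versions as standalone functions; the names `loop_sum`, `closed_sum`, and `max_value` are hypothetical, and the exhaustive ranges are arbitrary choices.

```python
def loop_sum(n: int, index: int, val: int) -> int:
    # O(n): step down by 1 on each side of index, floored at 1.
    return sum(max(1, val - abs(i - index)) for i in range(n))


def closed_sum(n: int, index: int, val: int) -> int:
    # O(1): same arithmetic-series formula as side_sum/bounded_sum above.
    if val <= 2:
        return n + (val - 1)

    def side(length: int) -> int:
        k = min(length, val - 2)  # neighbours whose value stays above 1
        m = val - 1
        return m * (m + 1) // 2 - (m - k) * (m - k + 1) // 2 + (length - k)

    return side(index) + val + side(n - 1 - index)


def max_value(n: int, index: int, max_sum: int) -> int:
    # Rightmost val in [1, max_sum] whose minimal array fits in max_sum.
    lo, hi = 1, max_sum
    while lo <= hi:
        mid = (lo + hi) // 2
        if closed_sum(n, index, mid) > max_sum:
            hi = mid - 1
        else:
            lo = mid + 1
    return lo - 1


if __name__ == "__main__":
    for n in range(1, 9):
        for index in range(n):
            for val in range(1, 15):
                assert loop_sum(n, index, val) == closed_sum(n, index, val)
    print(max_value(4, 2, 6), max_value(6, 1, 10))  # 2 3
```

The two printed values match LeetCode's two samples; the exhaustive loop only covers small n and val, so it is a sanity check rather than a proof.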
