LeetCode 1442: Count Triplets That Can Form Two Arrays of Equal XOR

All triplets

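For every triplet (i, j, k) with i < j <= k, we check whether xor(i, j) = xor(j, k+1), where xor(begin, end) is the XOR of arr[begin : end] with end exclusive.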

Time: \mathcal{O}(n^4), space: \mathcal{O}(1).

from typing import List


class Solution:
    def xor(self, begin, end, arr):
        # XOR of arr[begin:end] (end is exclusive).
        result = 0
        for i in range(begin, end):
            result ^= arr[i]
        return result

    def countTriplets(self, arr: List[int]) -> int:
        count = 0
        for k in range(1, len(arr)):
            for i in range(0, k):
                for j in range(i+1, k+1):
                    if self.xor(i, j, arr) == self.xor(j, k+1, arr):
                        count += 1

        return count

Prefix XOR

Let XOR(i, j) denote the XOR of the numbers in the sublist arr[i : j+1], both ends inclusive. Then XOR(i, j) = XOR(0, j) \oplus XOR(0, i-1) for i > 0, because the shared prefix arr[0 : i] cancels out. So we can precompute the XOR of every prefix and avoid recomputing sublist XORs inside the inner loop.
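The identity can be checked on a small input. This is a minimal sketch, not the layout used in the solution that follows: it assumes a prefix array with a leading 0 (one slot longer than arr) so that i = 0 needs no special case, and the helper names build_prefix_xor and range_xor are made up for illustration.

```python
from typing import List


def build_prefix_xor(arr: List[int]) -> List[int]:
    # prefix[t] is the XOR of the first t elements; prefix[0] = 0 is the empty prefix.
    prefix = [0] * (len(arr) + 1)
    for idx, value in enumerate(arr):
        prefix[idx + 1] = prefix[idx] ^ value
    return prefix


def range_xor(prefix: List[int], i: int, j: int) -> int:
    # XOR of arr[i..j], both ends inclusive: the common prefix arr[0..i-1] cancels.
    return prefix[j + 1] ^ prefix[i]


arr = [2, 3, 1, 6, 7]
prefix = build_prefix_xor(arr)
print(prefix)                   # [0, 2, 1, 0, 6, 1]
print(range_xor(prefix, 1, 3))  # 3 ^ 1 ^ 6 = 4
```

With the leading 0, the conditional `if i > 0 else 0` becomes unnecessary, at the cost of shifting every index by one.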

Time: \mathcal{O}(n^3), space: \mathcal{O}(n).

from typing import List


class Solution:
    def countTriplets(self, arr: List[int]) -> int:
        # prefix_xor[i] holds arr[0] ^ arr[1] ^ ... ^ arr[i].
        prefix_xor = [0] * len(arr)
        prefix_xor[0] = arr[0]
        for i in range(1, len(arr)):
            prefix_xor[i] = prefix_xor[i-1] ^ arr[i]

        count = 0
        for k in range(1, len(arr)):
            for i in range(0, k):
                for j in range(i+1, k+1):
                    # left_xor is XOR(i, j-1) and right_xor is XOR(j, k).
                    left_xor = prefix_xor[j-1] ^ (prefix_xor[i-1] if i > 0 else 0)
                    right_xor = prefix_xor[k] ^ prefix_xor[j-1]
                    if left_xor == right_xor:
                        count += 1

        return count

Two more observations:

  1. XOR(i, j-1) = XOR(j, k) means XOR(i, k) = 0. Since XOR(0, k) = XOR(0, i-1) \oplus XOR(i, k), it follows that XOR(0, k) = XOR(0, i-1). So, for each prefix XOR, we need to find the earlier prefix XORs that are equal to it.
  2. XOR(i, k) = 0 can be written as XOR(i, i) = XOR(i+1, k), or XOR(i, i+1) = XOR(i+2, k), and so on up to XOR(i, k-1) = XOR(k, k). Every split point j from i+1 to k gives a valid triplet, so we have (k-i) valid triplets.

So, in prefix_xor, whenever prefix_xor[i] = prefix_xor[k] with i < k, we know XOR(i+1, k) = 0 and we can add k-(i+1) to the count. The empty prefix before index 0, whose XOR is 0, also counts as an earlier prefix.
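The second observation can be illustrated with a brute-force check. This is only a sketch for building intuition; the helper valid_splits is a hypothetical name, and it recomputes each XOR from scratch rather than using prefix_xor.

```python
from functools import reduce
from operator import xor
from typing import List


def valid_splits(arr: List[int], i: int, k: int) -> List[int]:
    # Every j in i+1..k with XOR(arr[i..j-1]) == XOR(arr[j..k]).
    return [
        j
        for j in range(i + 1, k + 1)
        if reduce(xor, arr[i:j], 0) == reduce(xor, arr[j:k + 1], 0)
    ]


arr = [2, 3, 1, 6, 7]
# arr[0..2] = [2, 3, 1] XORs to 0, so every split point works: k - i = 2 of them.
print(valid_splits(arr, 0, 2))  # [1, 2]
# arr[0..1] = [2, 3] XORs to 1, so no split point works.
print(valid_splits(arr, 0, 1))  # []
```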

Time: \mathcal{O}(n^2), space: \mathcal{O}(n).

from typing import List


class Solution:
    def countTriplets(self, arr: List[int]) -> int:
        # prefix_xor[i] holds arr[0] ^ arr[1] ^ ... ^ arr[i].
        prefix_xor = [0] * len(arr)
        prefix_xor[0] = arr[0]
        for i in range(1, len(arr)):
            prefix_xor[i] = prefix_xor[i-1] ^ arr[i]

        count = 0
        for i in range(-1, len(arr)):
            # i == -1 stands for the empty prefix, whose XOR is 0.
            left_xor = prefix_xor[i] if i >= 0 else 0
            for j in range(i+1, len(arr)):
                if left_xor == prefix_xor[j]:
                    # arr[i+1..j] XORs to 0 and yields j-i-1 triplets.
                    count += (j-i-1)

        return count
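The first observation (match each prefix XOR against equal earlier ones) also suggests a single-pass variant using hash maps. The following is a sketch of that idea and is not part of the solutions above; the function name count_triplets_linear and the two dictionaries are assumptions made for illustration. For a prefix ending at position b, each earlier position a with the same prefix XOR contributes b-a-1, so it is enough to track how many such positions exist and the sum of their indices. Positions here are 1-based prefix lengths, with position 0 for the empty prefix.

```python
from collections import defaultdict
from typing import DefaultDict, List


def count_triplets_linear(arr: List[int]) -> int:
    # seen_count[v]: number of earlier prefix positions whose XOR equals v.
    # seen_index_sum[v]: sum of those positions.
    seen_count: DefaultDict[int, int] = defaultdict(int)
    seen_index_sum: DefaultDict[int, int] = defaultdict(int)
    seen_count[0] = 1  # the empty prefix at position 0

    prefix = 0
    total = 0
    for b, value in enumerate(arr, start=1):
        prefix ^= value
        # Sum of (b - a - 1) over all earlier positions a with the same prefix XOR.
        total += seen_count[prefix] * (b - 1) - seen_index_sum[prefix]
        seen_count[prefix] += 1
        seen_index_sum[prefix] += b
    return total


print(count_triplets_linear([2, 3, 1, 6, 7]))  # 4
print(count_triplets_linear([1, 1, 1, 1, 1]))  # 10
```

Under these assumptions the loop does constant expected work per element, so the running time would be \mathcal{O}(n) with \mathcal{O}(n) extra space.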
