LeetCode 1803. Count Pairs With XOR in a Range Solution

Problem Explanation

We are given an integer array nums (0-indexed) and two integers low and high. Our task is to return the number of nice pairs. A nice pair is defined as a pair (i, j) where 0 <= i < j < nums.length and low <= (nums[i] XOR nums[j]) <= high.

Example Walkthrough

Let's consider this example:

nums = [1, 4, 2, 7], low = 2, high = 6

All nice pairs (i, j) are as follows:

  • (0, 1): nums[0] XOR nums[1] = 1 XOR 4 = 5
  • (0, 2): nums[0] XOR nums[2] = 1 XOR 2 = 3
  • (0, 3): nums[0] XOR nums[3] = 1 XOR 7 = 6
  • (1, 2): nums[1] XOR nums[2] = 4 XOR 2 = 6
  • (1, 3): nums[1] XOR nums[3] = 4 XOR 7 = 3
  • (2, 3): nums[2] XOR nums[3] = 2 XOR 7 = 5

Every one of the six pairs has an XOR value in [2, 6], so the output is 6.
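As a baseline for checking the faster solution, here is a minimal brute-force sketch that simply tests every pair (i, j) with i < j. The function name countPairsBrute is hypothetical and not part of the original solution. It runs in O(n^2) time, so it is only meant for small inputs or for testing.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical brute-force reference: checks every pair (i, j) with i < j
// and counts those whose XOR lies in the inclusive range [low, high].
int countPairsBrute(const std::vector<int>& nums, int low, int high) {
  int count = 0;
  for (std::size_t i = 0; i < nums.size(); ++i) {
    for (std::size_t j = i + 1; j < nums.size(); ++j) {
      const int x = nums[i] ^ nums[j];
      if (low <= x && x <= high) ++count;
    }
  }
  return count;
}
```

For the example above, countPairsBrute({1, 4, 2, 7}, 2, 6) returns 6, and countPairsBrute({9, 8, 4, 2, 1}, 5, 14) returns 8.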

Solution Approach

We implement a solution that uses a binary Trie to count the nice pairs efficiently. Each Trie node has two children, one for a 0 bit and one for a 1 bit, and numbers are inserted from the most significant bit down. Each node also stores a count: the number of inserted values whose leading bits follow the path from the root to that node. The solution iterates over nums, queries the Trie for each element, and then inserts that element into the Trie.

We have the following steps to implement the solution:

  1. Iterate over each number in the input array nums.
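  2. For each number num, count the previously inserted values x with (num XOR x) < high + 1, and separately those with (num XOR x) < low. getCount finds each count by walking the Trie from the most significant bit: where limit has a 1 bit, every x whose bit matches num's bit gives an XOR bit of 0 there, so that whole child is already below limit and its count is added, and the walk continues into the other child, where the XOR is still tied with limit; where limit has a 0 bit, the XOR bit must also be 0, so the walk follows the child that matches num's bit.
  3. Subtract the second count from the first. The difference is the number of earlier values whose XOR with num lies in [low, high]; add it to the answer.
  4. Insert num into the Trie. Inserting after the query means num is only paired with values that appear before it, so each pair (i, j) with i < j is counted exactly once.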
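The steps above amount to the following identity. Here f_j(L) is a name introduced only for illustration: it counts the earlier indices whose XOR with nums[j] is below L, and n is the length of nums. The identity holds because, for integers, low <= v <= high is the same as v < high + 1 with the cases v < low removed.

```latex
f_j(L) = \bigl|\{\, i : 0 \le i < j,\ \mathit{nums}[i] \oplus \mathit{nums}[j] < L \,\}\bigr|

\mathit{answer} = \sum_{j=0}^{n-1} \bigl( f_j(\mathit{high} + 1) - f_j(\mathit{low}) \bigr)
```

For the example (high + 1 = 7, low = 2), the terms for j = 0, 1, 2, 3 are 0 - 0, 1 - 0, 2 - 0 and 3 - 0, which sum to 6.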

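Here is the Trie after all four numbers of the example have been inserted. For readability it shows only the lowest 3 bits, since every value is below 8; in the implementation, the 12 higher bits (14 down to 3) are all 0 for these values and form a single shared path above these levels. Each node is written as bit(count).

nums = [1, 4, 2, 7], low = 2, high = 6
Trie (lowest 3 bits):
root
  0(2)
    0(1) - 1(1)    001 = 1
    1(1) - 0(1)    010 = 2
  1(2)
    0(1) - 0(1)    100 = 4
    1(1) - 1(1)    111 = 7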
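To make the drawing and the bit walk checkable by hand, here is a minimal sketch that rebuilds this 3-bit Trie and runs the same "count x with (num XOR x) < limit" walk. It is an illustration, not the solution's code: the names Node, insert, buildTrie, countLess and the constant kBits = 3 are assumptions made here, and 3 bits only suffice because every value in the example, and every limit used (high + 1 = 7, low = 2), is below 8. It uses unique_ptr because each node has exactly one owner.

```cpp
#include <array>
#include <memory>
#include <vector>

// Assumption for this sketch only: all values and limits fit in 3 bits.
constexpr int kBits = 3;

struct Node {
  std::array<std::unique_ptr<Node>, 2> children;  // [0] = 0 bit, [1] = 1 bit
  int count = 0;  // number of inserted values whose prefix reaches this node
};

// Inserts num from the most significant of the kBits bits down.
void insert(Node& root, int num) {
  Node* node = &root;
  for (int i = kBits - 1; i >= 0; --i) {
    const int bit = (num >> i) & 1;
    if (!node->children[bit]) node->children[bit] = std::make_unique<Node>();
    node = node->children[bit].get();
    ++node->count;
  }
}

// Builds a Trie holding every value in nums.
Node buildTrie(const std::vector<int>& nums) {
  Node root;
  for (int num : nums) insert(root, num);
  return root;
}

// Returns the number of inserted values x with (num ^ x) < limit.
int countLess(const Node& root, int num, int limit) {
  int count = 0;
  const Node* node = &root;
  for (int i = kBits - 1; i >= 0 && node != nullptr; --i) {
    const int bit = (num >> i) & 1;
    if ((limit >> i) & 1) {
      // limit has a 1 here: values matching num's bit make the XOR bit 0,
      // so their whole subtree is already below limit.
      if (node->children[bit]) count += node->children[bit]->count;
      // The XOR-bit-1 branch is still tied with limit; keep walking there.
      node = node->children[bit ^ 1].get();
    } else {
      // limit has a 0 here, so the XOR bit must be 0 as well.
      node = node->children[bit].get();
    }
  }
  return count;
}
```

For instance, when 7 is processed, the Trie holds {1, 4, 2}: countLess(buildTrie({1, 4, 2}), 7, 7) returns 3 (7 XOR 1 = 6, 7 XOR 4 = 3 and 7 XOR 2 = 5 are all below 7) and countLess(buildTrie({1, 4, 2}), 7, 2) returns 0, so 7 contributes 3 - 0 = 3 nice pairs.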

Solution Implementation

#include <memory>
#include <vector>

using namespace std;

// A binary Trie node: children[0] / children[1] follow a 0 / 1 bit, and
// count is the number of inserted values whose bit prefix reaches this node.
struct TrieNode {
  vector<shared_ptr<TrieNode>> children;
  int count = 0;
  TrieNode() : children(2) {}
};

class Solution {
 public:
  int countPairs(vector<int>& nums, int low, int high) {
    int ans = 0;

    for (const int num : nums) {
      // Earlier values x with low <= (num ^ x) <= high.
      ans += getCount(num, high + 1) - getCount(num, low);
      // Insert after querying so each pair (i, j) with i < j is counted once.
      insert(num);
    }

    return ans;
  }

 private:
  // Bits 14..0 cover every value up to 2^15 - 1 = 32767, which assumes the
  // LeetCode constraints (nums[i], low, high <= 2 * 10^4).
  static constexpr int kHeight = 14;
  shared_ptr<TrieNode> root = make_shared<TrieNode>();

  void insert(int num) {
    shared_ptr<TrieNode> node = root;
    for (int i = kHeight; i >= 0; --i) {
      const int bit = (num >> i) & 1;
      if (node->children[bit] == nullptr)
        node->children[bit] = make_shared<TrieNode>();
      node = node->children[bit];
      ++node->count;
    }
  }

  // Returns the number of inserted values x with (num ^ x) < limit.
  int getCount(int num, int limit) {
    int count = 0;
    shared_ptr<TrieNode> node = root;
    for (int i = kHeight; i >= 0; --i) {
      const int bit = (num >> i) & 1;
      const int bitLimit = (limit >> i) & 1;
      if (bitLimit == 1) {
        // An XOR bit of 0 here makes (num ^ x) < limit whatever the lower bits.
        if (node->children[bit] != nullptr)
          count += node->children[bit]->count;
        // An XOR bit of 1 keeps (num ^ x) tied with limit; keep walking.
        node = node->children[bit ^ 1];
      } else {
        // limit has a 0 here, so the XOR bit must be 0 as well.
        node = node->children[bit];
      }
      if (node == nullptr)
        break;
    }
    return count;
  }
};

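This Trie-based solution counts the nice pairs for the inclusive range [low, high]. Each insert and each getCount call walks at most kHeight + 1 = 15 levels, so the running time is O(15n), and the Trie holds at most 15n nodes, so the extra space is also O(15n), where n is the length of nums.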