1885. Count Pairs in Two Arrays


Problem Description

The problem is to determine the number of specific pairs of indices in two arrays, nums1 and nums2. Both arrays have the same length n. A valid pair (i, j) must satisfy two conditions: i < j (meaning the first index must be less than the second one), and nums1[i] + nums1[j] > nums2[i] + nums2[j] (meaning the sum of nums1 elements at these indices is greater than the sum of nums2 elements at the same indices). We need to count how many such pairs exist.

The task is, given the two arrays nums1 and nums2, to return the total count of pairs (i, j) that meet these conditions.

Intuition

To solve this problem, we leverage the fact that if we fix one index and sort all the potential pair values, we can then use binary search to efficiently find how many values meet our condition for each fixed index. Here's a step-by-step explanation of the solution:

  1. First, compute the difference between the elements at the same indices in nums1 and nums2 and store these differences in a new array d. This transformation simplifies the problem, as we're now looking for indices where d[i] + d[j] > 0.

  2. Sort the array d in non-decreasing order. Sorting enables us to use binary search to efficiently find indices satisfying our condition.

  3. For each element d[i] in d, we want to find the count of elements d[j] such that j > i and d[i] + d[j] > 0. We can rewrite this condition to d[j] > -d[i]. Using binary search on the sorted array d, we look for the right-most position to which -d[i] could be inserted while maintaining the sorted order. This gives us the index beyond which all elements of d[j] would result in d[i] + d[j] > 0.

  4. The bisect_right function from Python's bisect module is used for this purpose. For each i, it returns the index beyond which -d[i] would go in the sorted array.

  5. The count of valid j for each i is the number of elements in d beyond the index found in step 4, which is simply n - (index found by bisect_right).

  6. The total count of valid pairs is obtained by summing the count from step 5 for each i.

Using this method, we reduce a potentially O(n^2) problem (checking each pair directly) to O(n log n) due to sorting and binary search for each element.

Learn more about Binary Search and Sorting patterns.

Not Sure What to Study? Take the 2-min Quiz to Find Your Missing Piece:

How does merge sort divide the problem into subproblems?

Solution Approach

The implemented solution follows these steps, using a mix of algorithmic techniques and Python-specific functionalities:

  1. Difference Calculation and Store in Array d: The first step is to calculate the difference array d, where each element d[i] is the difference between nums1[i] and nums2[i] for i from 0 to n-1. This subtraction is done using a list comprehension, which is a concise way to create lists in Python.

    1d = [nums1[i] - nums2[i] for i in range(n)]
  2. Sorting the Array d: The array d is then sorted in non-decreasing order. This sorting is crucial as it prepares the array for a binary search operation. The sorted property of d allows us to apply the bisect algorithm effectively.

    1d.sort()
  3. Using Binary Search to Find Count of Valid Pairs: For each element in d, we use the bisect_right function from the bisect module to find the insertion point for -d[i] into d such that the array remains sorted.

    The function bisect_right is a binary search algorithm that returns the index in the sorted list d, where the value -d[i] should be inserted to maintain the sorted order. The lo parameter signifies the start position for the search which in this case is i + 1, ensuring that j > i.

    The subtraction from n gives us the number of elements larger than -d[i], effectively counting how many j indices will satisfy the condition d[i] + d[j] > 0.

    1sum(n - bisect_right(d, -v, lo=i + 1) for i, v in enumerate(d))
  4. Summing the Counts for Each i: The sum operation in the final line adds up the valid pairs count for each value of i. It iterates over the sorted array d and for each element calculates the number of valid pairs (i, j) where i < j and the sum of the original elements from nums1 at these indices is greater than the sum of the elements from nums2 at the same indices, which corresponds to the condition d[i] + d[j] > 0 post-transformation.

This solution uses a combination of algorithmic concepts:

  • Transformation: to simplify the original condition to a more manageable form.
  • Sorting: to prepare data for efficient searching.
  • Binary Search: to reduce the search space for the pairs from O(n) to O(log n), greatly enhancing the overall algorithm efficiency.
  • Prefix Sum: implicit in the adding up of counts for each index, effectively reducing the number of direct comparisons needed.

Returning the sum at the end gives us the desired count of pairs that fulfill the problem’s conditions efficiently.

Discover Your Strengths and Weaknesses: Take Our 2-Minute Quiz to Tailor Your Study Plan:

Which of the following is equvalent to O(3*2^n + n^3 + n!+ log n)?

Example Walkthrough

Let's use a small example with the arrays nums1 = [3, -1, 7] and nums2 = [4, 0, 5] to illustrate the solution approach step-by-step. The length of both arrays, n, is 3. Our goal is to find the count of valid pairs (i, j) for which i < j and nums1[i] + nums1[j] > nums2[i] + nums2[j].

Step 1: Calculate difference array d

First, find the difference between corresponding elements of nums1 and nums2:

  • d[0] = nums1[0] - nums2[0] = 3 - 4 = -1
  • d[1] = nums1[1] - nums2[1] = -1 - 0 = -1
  • d[2] = nums1[2] - nums2[2] = 7 - 5 = 2

So the difference array d is [-1, -1, 2].

Step 2: Sort the array d

We sort the array d to get [-1, -1, 2]. In this small case, sorting does not change the order, as the list is already in non-decreasing order.

Step 3: Use binary search for each i

We use bisect_right to find where -d[i] can be inserted:

  • For i = 0 (d[0] = -1): We search for where 1 can be inserted after index 0. bisect_right([-1, -1, 2], 1, lo=0 + 1) = 3.
  • For i = 1 (d[1] = -1): We search for where 1 can be inserted after index 1. bisect_right([-1, -1, 2], 1, lo=1 + 1) = 3.
  • We do not search for i = 2 because it's the last element, and no j can satisfy i < j.

Step 4: Calculate the valid j indices and sum them up

  • For i = 0: The count of valid j indices is n - 3 = 3 - 3 = 0.
  • For i = 1: The count of valid j indices is n - 3 = 3 - 3 = 0.

The total count of valid pairs (i, j) is the sum of the counts above: 0 + 0 = 0.

Therefore, for the given arrays nums1 and nums2, there are no valid pairs (i, j) that meet the condition nums1[i] + nums1[j] > nums2[i] + nums2[j].

Solution Implementation

1from typing import List
2from bisect import bisect_right
3
4class Solution:
5    def countPairs(self, nums1: List[int], nums2: List[int]) -> int:
6        # Length of the input lists
7        length = len(nums1)
8        # Calculate the difference between the two lists element-wise
9        differences = [nums1[i] - nums2[i] for i in range(length)]
10        # Sort the differences to prepare for binary search
11        differences.sort()
12      
13        # Initialize count of pairs to 0
14        count = 0
15        # Loop through the sorted differences list
16        for i, value in enumerate(differences):
17            # For each element, find the number of elements in the sorted
18            # list that would create a negative sum with the current element.
19            # The `bisect_right` function is used to find the position to
20            # insert `-value` which gives the number of such elements.
21            # Subtract this position from the total number of elements that
22            # can be paired with the current element, which is (length - i - 1).
23            # We use `lo=i+1` because we shouldn't pair an element with itself.
24            count += length - bisect_right(differences, -value, lo=i + 1)
25      
26        # Return the total count of valid pairs
27        return count
28
1class Solution {
2    public long countPairs(int[] nums1, int[] nums2) {
3        // Get the length of the arrays
4        int n = nums1.length;
5
6        // Create a new array to store the differences between nums1 and nums2
7        int[] differences = new int[n];
8        for (int i = 0; i < n; ++i) {
9            differences[i] = nums1[i] - nums2[i];
10        }
11
12        // Sort the array of differences
13        Arrays.sort(differences);
14
15        // Initialize answer to count the valid pairs
16        long answer = 0;
17
18        // Iterate over each element in the differences array
19        for (int i = 0; i < n; ++i) {
20            // Use binary search to find the number of valid pairs
21            int left = i + 1, right = n;
22            while (left < right) {
23                int mid = (left + right) / 2;
24                // Check if this position contributes to a valid pair
25                if (differences[mid] > -differences[i]) {
26                    right = mid;
27                } else {
28                    left = mid + 1;
29                }
30            }
31            // Add the count of valid pairs for this position to the answer
32            answer += n - left;
33        }
34
35        // Return the total number of valid pairs
36        return answer;
37    }
38}
39
1class Solution {
2public:
3    long long countPairs(vector<int>& nums1, vector<int>& nums2) {
4        // Get the size of the input vectors
5        int size = nums1.size();
6      
7        // Create a difference vector to store differences of nums1[i] - nums2[i]
8        vector<int> diff(size);
9      
10        // Populate the difference vector
11        for (int i = 0; i < size; ++i) {
12            diff[i] = nums1[i] - nums2[i];
13        }
14      
15        // Sort the difference vector in non-decreasing order
16        sort(diff.begin(), diff.end());
17      
18        // Initialize result variable to store the final count of pairs
19        long long result = 0;
20      
21        // Iterate through each element in the difference vector
22        for (int i = 0; i < size; ++i) {
23            // Find the index of the first element that is greater than -diff[i]
24            // This is done to ensure that for any pair (i, j), diff[i] + diff[j] > 0
25            int j = upper_bound(diff.begin() + i + 1, diff.end(), -diff[i]) - diff.begin();
26          
27            // Increment the result by the number of valid pairs with the current element at index i
28            result += size - j;
29        }
30      
31        // Return the computed number of valid pairs
32        return result;
33    }
34};
35
1function countPairs(nums1: number[], nums2: number[]): bigint {
2    // Get the size of the input arrays
3    const size: number = nums1.length;
4  
5    // Create a difference array to store differences of nums1[i] - nums2[i]
6    let diff: number[] = new Array<number>(size);
7  
8    // Populate the difference array
9    for (let i = 0; i < size; i++) {
10        diff[i] = nums1[i] - nums2[i];
11    }
12  
13    // Sort the difference array in non-decreasing order (ascending)
14    diff.sort((a, b) => a - b);
15  
16    // Initialize result variable to store the final count of pairs
17    let result: bigint = BigInt(0);
18  
19    // Iterate through each element in the difference array
20    for (let i = 0; i < size; i++) {
21        // Find the index of the first element that is strictly greater than -diff[i]
22        // This is done to ensure that for any pair (i, j), diff[i] + diff[j] > 0
23        let j: number = findUpperBound(diff, i + 1, size, -diff[i]);
24      
25        // Increment the result by the number of valid pairs with the current element at index i
26        result += BigInt(size - j);
27    }
28  
29    // Return the computed number of valid pairs
30    return result;
31}
32
33function findUpperBound(arr: number[], start: number, end: number, value: number): number {
34    // Binary search for the first element in the sorted array that is strictly greater than the given value
35    let low: number = start;
36    let high: number = end;
37  
38    while (low < high) {
39        const mid: number = low + Math.floor((high - low) / 2);
40        if (arr[mid] <= value) {
41            low = mid + 1;
42        } else {
43            high = mid;
44        }
45    }
46  
47    // Return the index where the value would be inserted (first index greater than the value)
48    return low;
49}
50
Not Sure What to Study? Take the 2-min Quiz:

Which type of traversal does breadth first search do?

Time and Space Complexity

Time Complexity

The time complexity of the code can be broken down into several steps:

  1. Create a difference list d by subtracting nums2 from nums1. This step is O(n) where n is the length of the input lists.
  2. Sort the difference list d. Sorting algorithms generally have a time complexity of O(n log n).
  3. For each element in d, perform a binary search using bisect_right. Since we perform a binary search (O(log n)) for each element in the list, this step has a time complexity of O(n log n).

Adding these up, the overall time complexity is dominated by the sorting and binary search steps, which leads to O(n log n).

Space Complexity

The space complexity is evaluated as follows:

  1. We are creating a difference list d of size n, therefore requiring O(n) additional space.
  2. Sorting the list in-place (as Python’s sort does) has a space complexity of O(log n), as typical implementations of sorting algorithms, like Timsort (used by Python's sort method), use O(log n) space.
  3. The binary search itself does not use additional space (aside from a few pointers), so the space used remains O(log n).

As the additional space required for the difference list is the largest contributor, the overall space complexity is O(n).

Learn more about how to find time and space complexity quickly using problem constraints.

Fast Track Your Learning with Our Quick Skills Quiz:

What are the two properties the problem needs to have for dynamic programming to be applicable? (Select 2)


Recommended Readings


Got a question? Ask the Teaching Assistant anything you don't understand.

Still not clear? Ask in the Forum,  Discord or Submit the part you don't understand to our editors.


TA 👨‍🏫