
2040. Kth Smallest Product of Two Sorted Arrays

Problem Description

You are given two sorted integer arrays nums1 and nums2 (both 0-indexed), and an integer k. Your task is to find the k-th smallest product among all possible products formed by multiplying one element from nums1 with one element from nums2.

Specifically, you need to consider all products of the form nums1[i] * nums2[j] where 0 <= i < nums1.length and 0 <= j < nums2.length. After calculating all these products, sort them in ascending order and return the k-th value (using 1-based indexing, meaning k=1 refers to the smallest product).

For example, if nums1 = [2, 5] and nums2 = [3, 4], the possible products are:

  • 2 * 3 = 6
  • 2 * 4 = 8
  • 5 * 3 = 15
  • 5 * 4 = 20

Sorted in ascending order: [6, 8, 15, 20]. If k = 2, the answer would be 8.

Note that the arrays can contain negative numbers, zero, or positive numbers, and the arrays are already sorted in non-decreasing order.


Intuition

The naive approach would be to generate all possible products nums1[i] * nums2[j], sort them, and return the k-th element. However, each array can hold up to 5 × 10^4 elements, so there can be up to 2.5 × 10^9 products. That is far too many to store or sort.
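The naive approach is still useful on tiny inputs, as a reference to test faster versions against. Below is a minimal sketch. The function name kth_smallest_product_brute is our own and does not come from the original solution.

```python
from itertools import product
from typing import List


def kth_smallest_product_brute(nums1: List[int], nums2: List[int], k: int) -> int:
    """Materialize every product, sort, and return the k-th smallest (1-indexed).

    O(m * n * log(m * n)) time and O(m * n) memory, so only for tiny inputs.
    """
    products = sorted(a * b for a, b in product(nums1, nums2))
    return products[k - 1]


print(kth_smallest_product_brute([2, 5], [3, 4], 2))  # 8
```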

The key insight is that we don't need to generate all products explicitly. Instead, we can use binary search on the answer space. If we pick a candidate value p, we can efficiently count how many products are less than or equal to p without generating all products.

Why binary search? The number of products <= p never decreases as p grows. We can therefore binary search for the smallest p such that at least k products are <= p. For that p, count(p - 1) < k <= count(p), which means the k-th smallest product is exactly p.
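The monotonicity argument can be checked directly on a small example. The sketch below uses a brute-force count_le (for illustration only) together with a generic "first p where the predicate becomes true" search. Both helper names are hypothetical. It confirms that, for every k, the smallest p with count_le(p) >= k is the k-th smallest product.

```python
from typing import Callable


def first_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    """Smallest p in [lo, hi] with pred(p) True.

    Assumes pred is monotone (False ... False True ... True) and pred(hi) is True.
    """
    while lo < hi:
        mid = (lo + hi) // 2  # // floors toward -inf, so the interval always shrinks
        if pred(mid):
            hi = mid
        else:
            lo = mid + 1
    return lo


nums1, nums2 = [-2, 1, 3], [-1, 2, 4]
products = sorted(a * b for a in nums1 for b in nums2)


def count_le(p: int) -> int:
    # Brute-force count of products <= p, only to illustrate the idea
    return sum(1 for v in products if v <= p)


for k in range(1, len(products) + 1):
    answer = first_true(products[0], products[-1], lambda p: count_le(p) >= k)
    assert answer == products[k - 1]
```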

The challenge becomes: given a value p, how do we count products <= p efficiently?

For each element x in nums1, we need to count how many elements y in nums2 satisfy x * y <= p. This depends on the sign of x:

  • If x > 0: We need y <= p/x. Since nums2 is sorted and multiplication by positive x preserves order, we can use binary search to find the rightmost position where nums2[j] <= p/x.

  • If x < 0: We need y >= p/x (inequality flips when dividing by negative). Since nums2 is sorted and multiplication by negative x reverses order, larger values of y give smaller products. We can use binary search to find the leftmost position where nums2[j] >= p/x.

  • If x = 0: The product is always 0, so every element of nums2 counts when p >= 0 and none counts when p < 0.
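The three sign cases can also be written without floating-point division. Every y in nums2 is an integer, so y <= p/x is equivalent to y <= floor(p/x), and y >= p/x is equivalent to y >= ceil(p/x). The sketch below builds on that observation; it is an integer-only alternative, not the float-based version used later. The helper name count_for_x is hypothetical.

```python
from bisect import bisect_left, bisect_right
from typing import List


def count_for_x(x: int, p: int, nums2: List[int]) -> int:
    """Number of y in sorted nums2 with x * y <= p, using only integer math.

    Because every y is an integer:
      x > 0:  x * y <= p  <=>  y <= floor(p / x)
      x < 0:  x * y <= p  <=>  y >= ceil(p / x)
    """
    if x > 0:
        return bisect_right(nums2, p // x)  # // is floor division in Python
    if x < 0:
        ceil_q = -(-p // x)  # ceil(p / x) expressed through floor division
        return len(nums2) - bisect_left(nums2, ceil_q)
    return len(nums2) if p >= 0 else 0  # x == 0: every product is 0
```

For example, with nums2 = [-1, 2, 4], count_for_x(-2, 2, nums2) counts y >= -1 and returns 3.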

By efficiently counting products for each candidate p, we can binary search for the exact k-th smallest product value. The search range is bounded by the maximum possible product magnitude, which is the product of the maximum absolute values from both arrays.
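In a sorted array, the element with the largest absolute value always sits at one of the two ends. That is why the bound only needs the endpoints. The sketch below computes the bound and then runs a randomized sanity check that every product falls inside it. The name symmetric_bounds is our own.

```python
import random
from typing import List, Tuple


def symmetric_bounds(nums1: List[int], nums2: List[int]) -> Tuple[int, int]:
    """Return (-mx, mx), where mx multiplies the largest magnitudes of the two arrays."""
    mx = max(abs(nums1[0]), abs(nums1[-1])) * max(abs(nums2[0]), abs(nums2[-1]))
    return -mx, mx


print(symmetric_bounds([-2, 1, 3], [-1, 2, 4]))  # (-12, 12)

# Randomized sanity check: every product lies inside the bounds
for _ in range(1000):
    a = sorted(random.randint(-50, 50) for _ in range(random.randint(1, 6)))
    b = sorted(random.randint(-50, 50) for _ in range(random.randint(1, 6)))
    lo, hi = symmetric_bounds(a, b)
    assert all(lo <= x * y <= hi for x in a for y in b)
```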


Solution Approach

The solution uses binary search on the answer space combined with a counting function. Here's the step-by-step implementation:

1. Define the Search Range: First, we calculate the maximum possible absolute product value:

mx = max(abs(nums1[0]), abs(nums1[-1])) * max(abs(nums2[0]), abs(nums2[-1]))

Since the arrays are sorted, the extreme values are at the endpoints. The search range becomes [-mx, mx].

2. Implement the Counting Function: The count(p) function calculates how many products are less than or equal to p:

def count(p: int) -> int:
    cnt = 0
    n = len(nums2)
    for x in nums1:
        if x > 0:
            cnt += bisect_right(nums2, p / x)
        elif x < 0:
            cnt += n - bisect_left(nums2, p / x)
        else:
            cnt += n * int(p >= 0)
    return cnt

For each element x in nums1:

  • When x > 0: We find how many elements in nums2 satisfy y <= p/x. Using bisect_right(nums2, p/x) gives us the count of such elements since it returns the insertion position after all elements <= p/x.

  • When x < 0: We need elements where y >= p/x (inequality flips). Using bisect_left(nums2, p/x) finds the first position where nums2[j] >= p/x. Therefore, n - bisect_left(nums2, p/x) gives the count of valid elements.

  • When x = 0: All products equal 0. If p >= 0, all n elements of nums2 produce valid products; otherwise, none do.

3. Binary Search for the k-th Smallest: The main binary search uses Python's bisect_left with a custom key function. The key parameter requires Python 3.10 or later.

return bisect_left(range(-mx, mx + 1), k, key=count) - mx

This searches for the leftmost position in the range [-mx, mx] where count(p) >= k. The key=count parameter transforms each value p in the range to count(p) before comparison.

The function effectively finds the smallest p such that at least k products are <= p. Note that bisect_left returns an index into range(-mx, mx + 1), not a value, and index 0 corresponds to the value -mx. We therefore subtract mx from the index to recover the actual product value.
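The key= argument of bisect.bisect_left exists only in Python 3.10 and later. On older interpreters, the same search can be written as an explicit lower-bound loop over values rather than indices, which also removes the index-to-value offset. Below is a sketch of that loop. The function name is our own, and the float-based counting rule is copied from count(p) above.

```python
from bisect import bisect_left, bisect_right
from typing import List


def kth_smallest_product(nums1: List[int], nums2: List[int], k: int) -> int:
    n = len(nums2)

    def count(p: int) -> int:
        # Same counting rule as count(p) above: number of products <= p
        cnt = 0
        for x in nums1:
            if x > 0:
                cnt += bisect_right(nums2, p / x)
            elif x < 0:
                cnt += n - bisect_left(nums2, p / x)
            else:
                cnt += n * int(p >= 0)
        return cnt

    mx = max(abs(nums1[0]), abs(nums1[-1])) * max(abs(nums2[0]), abs(nums2[-1]))
    lo, hi = -mx, mx  # search over values directly, so no index offset is needed
    while lo < hi:
        mid = (lo + hi) // 2  # floors toward -inf, so the loop always makes progress
        if count(mid) >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo
```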

Time Complexity: O(m * log(n) * log(P)) where m = len(nums1), n = len(nums2), and P = 2 * mx + 1 is the number of candidate values in the search range. The outer binary search takes O(log(P)) iterations, and each iteration calls count() which takes O(m * log(n)) time.

Space Complexity: O(1) as we only use constant extra space beyond the input arrays.


Example Walkthrough

Let's walk through a small example to illustrate the solution approach.

Example: nums1 = [-2, 1, 3], nums2 = [-1, 2, 4], k = 5

Step 1: Calculate all possible products (for understanding). Let's first see what all products look like:

  • -2 * -1 = 2
  • -2 * 2 = -4
  • -2 * 4 = -8
  • 1 * -1 = -1
  • 1 * 2 = 2
  • 1 * 4 = 4
  • 3 * -1 = -3
  • 3 * 2 = 6
  • 3 * 4 = 12

Sorted: [-8, -4, -3, -1, 2, 2, 4, 6, 12]. The 5th smallest is 2.

Step 2: Apply our binary search solution

Calculate search range:

  • mx = max(|-2|, |3|) * max(|-1|, |4|) = 3 * 4 = 12
  • Search range: [-12, 12]

Step 3: Binary search process

We'll binary search for the smallest value p where count(p) >= 5.

Let's evaluate count(p) at a few key values:

When p = 0:

  • For x = -2 (negative): Need y >= 0/-2 = 0. Elements in nums2 that are >= 0: [2, 4]. Count = 2
  • For x = 1 (positive): Need y <= 0/1 = 0. Elements in nums2 that are <= 0: [-1]. Count = 1
  • For x = 3 (positive): Need y <= 0/3 = 0. Elements in nums2 that are <= 0: [-1]. Count = 1
  • Total count(0) = 2 + 1 + 1 = 4 (less than k=5)

When p = 2:

  • For x = -2 (negative): Need y >= 2/-2 = -1. Elements in nums2 that are >= -1: [-1, 2, 4]. Count = 3
  • For x = 1 (positive): Need y <= 2/1 = 2. Elements in nums2 that are <= 2: [-1, 2]. Count = 2
  • For x = 3 (positive): Need y <= 2/3 = 0.67. Elements in nums2 that are <= 0.67: [-1]. Count = 1
  • Total count(2) = 3 + 2 + 1 = 6 (greater than or equal to k=5)

When p = 1:

  • For x = -2: Need y >= -0.5. Elements >= -0.5: [2, 4]. Count = 2
  • For x = 1: Need y <= 1. Elements <= 1: [-1]. Count = 1
  • For x = 3: Need y <= 0.33. Elements <= 0.33: [-1]. Count = 1
  • Total count(1) = 2 + 1 + 1 = 4 (less than k=5)

Since count(1) = 4 < 5 and count(2) = 6 >= 5, the binary search finds that p = 2 is the smallest value where count(p) >= 5, which matches our expected answer.

Key Insight: We never actually computed all 9 products. Instead, we efficiently counted how many products were less than or equal to candidate values, allowing us to pinpoint the k-th smallest product through binary search.

Solution Implementation

from bisect import bisect_left, bisect_right
from typing import List


class Solution:
    def kthSmallestProduct(self, nums1: List[int], nums2: List[int], k: int) -> int:
        def count_products_less_than_or_equal(target_product: int) -> int:
            """
            Count how many products from nums1[i] * nums2[j] are <= target_product
            """
            total_count = 0
            nums2_length = len(nums2)

            for num1 in nums1:
                if num1 > 0:
                    # For positive num1, find how many nums2[j] satisfy num1 * nums2[j] <= target_product
                    # This means nums2[j] <= target_product / num1
                    total_count += bisect_right(nums2, target_product / num1)
                elif num1 < 0:
                    # For negative num1, find how many nums2[j] satisfy num1 * nums2[j] <= target_product
                    # This means nums2[j] >= target_product / num1 (inequality flips for negative)
                    total_count += nums2_length - bisect_left(nums2, target_product / num1)
                else:
                    # For num1 = 0, product is always 0
                    # Count all of nums2 if target_product >= 0, otherwise count 0
                    total_count += nums2_length * int(target_product >= 0)

            return total_count

        # Calculate the maximum possible absolute value of products
        # This will be our search range boundary
        max_absolute_value = max(abs(nums1[0]), abs(nums1[-1])) * max(abs(nums2[0]), abs(nums2[-1]))

        # Binary search for the k-th smallest product in range [-max_absolute_value, max_absolute_value]
        # bisect_left finds the leftmost index where count >= k (key= requires Python 3.10+)
        # Subtract max_absolute_value to convert from index to actual value
        return bisect_left(range(-max_absolute_value, max_absolute_value + 1), k,
                           key=count_products_less_than_or_equal) - max_absolute_value
class Solution {
    private int[] nums1;
    private int[] nums2;

    /**
     * Find the k-th smallest product from all possible products of nums1[i] * nums2[j]
     * @param nums1 First sorted array
     * @param nums2 Second sorted array
     * @param k The k-th position (1-indexed)
     * @return The k-th smallest product
     */
    public long kthSmallestProduct(int[] nums1, int[] nums2, long k) {
        this.nums1 = nums1;
        this.nums2 = nums2;

        int m = nums1.length;
        int n = nums2.length;

        // Find the maximum absolute values in both arrays
        int maxAbsNums1 = Math.max(Math.abs(nums1[0]), Math.abs(nums1[m - 1]));
        int maxAbsNums2 = Math.max(Math.abs(nums2[0]), Math.abs(nums2[n - 1]));

        // Set binary search boundaries for possible products
        long right = (long) maxAbsNums1 * maxAbsNums2;  // Maximum possible product
        long left = -right;  // Minimum possible product (negative)

        // Binary search for the k-th smallest product
        while (left < right) {
            // The arithmetic shift floors toward negative infinity; (left + right) / 2 would
            // truncate toward zero and could stall when left + right is negative and odd
            long mid = (left + right) >> 1;

            // Count how many products are <= mid
            if (count(mid) >= k) {
                // If count >= k, the answer is in [left, mid]
                right = mid;
            } else {
                // If count < k, the answer is in [mid + 1, right]
                left = mid + 1;
            }
        }

        return left;
    }

    /**
     * Count how many products from nums1[i] * nums2[j] are less than or equal to threshold
     * @param threshold The upper bound for counting products
     * @return Number of products <= threshold
     */
    private long count(long threshold) {
        long count = 0;
        int n = nums2.length;

        // For each element in nums1, count valid products with nums2
        for (int x : nums1) {
            if (x > 0) {
                // For positive x, products grow with the index:
                // find the first index where x * nums2[index] > threshold
                int left = 0;
                int right = n;

                while (left < right) {
                    int mid = (left + right) >> 1;

                    if ((long) x * nums2[mid] > threshold) {
                        // Product too large, search in left half
                        right = mid;
                    } else {
                        // Product valid, search in right half for larger index
                        left = mid + 1;
                    }
                }

                // All products from index 0 to left-1 are valid
                count += left;

            } else if (x < 0) {
                // For negative x, products shrink as the index grows:
                // find the first index where x * nums2[index] <= threshold
                int left = 0;
                int right = n;

                while (left < right) {
                    int mid = (left + right) >> 1;

                    if ((long) x * nums2[mid] <= threshold) {
                        // Product valid, search in left half for smaller index
                        right = mid;
                    } else {
                        // Product still exceeds threshold, search in right half
                        left = mid + 1;
                    }
                }

                // All products from index left to n-1 are valid
                count += n - left;

            } else {
                // x == 0, product is always 0
                if (threshold >= 0) {
                    // All products with 0 are valid
                    count += n;
                }
            }
        }

        return count;
    }
}
class Solution {
public:
    long long kthSmallestProduct(vector<int>& nums1, vector<int>& nums2, long long k) {
        int m = nums1.size();
        int n = nums2.size();

        // Calculate the range for binary search
        // The maximum absolute value from nums1
        int maxAbsNums1 = max(abs(nums1[0]), abs(nums1[m - 1]));
        // The maximum absolute value from nums2
        int maxAbsNums2 = max(abs(nums2[0]), abs(nums2[n - 1]));
        // The maximum possible product value
        long long maxProduct = 1LL * maxAbsNums1 * maxAbsNums2;
        // Search range: [-maxProduct, maxProduct]
        long long left = -maxProduct;
        long long right = maxProduct;

        // Lambda function to count how many products are <= threshold
        auto countProductsLessOrEqual = [&](long long threshold) {
            long long count = 0;

            // For each element in nums1, count valid products with nums2
            for (int num1 : nums1) {
                if (num1 > 0) {
                    // For positive num1, find how many nums2[j] satisfy num1 * nums2[j] <= threshold
                    // Since num1 > 0, we need nums2[j] <= threshold / num1
                    int searchLeft = 0;
                    int searchRight = n;
                    while (searchLeft < searchRight) {
                        int mid = (searchLeft + searchRight) >> 1;
                        if (1LL * num1 * nums2[mid] > threshold) {
                            searchRight = mid;
                        } else {
                            searchLeft = mid + 1;
                        }
                    }
                    count += searchLeft;
                } else if (num1 < 0) {
                    // For negative num1, find how many nums2[j] satisfy num1 * nums2[j] <= threshold
                    // Since num1 < 0, we need nums2[j] >= threshold / num1
                    // Count elements from the right side
                    int searchLeft = 0;
                    int searchRight = n;
                    while (searchLeft < searchRight) {
                        int mid = (searchLeft + searchRight) >> 1;
                        if (1LL * num1 * nums2[mid] <= threshold) {
                            searchRight = mid;
                        } else {
                            searchLeft = mid + 1;
                        }
                    }
                    count += n - searchLeft;
                } else {
                    // num1 == 0, product is always 0
                    // If threshold >= 0, all products with nums2 are valid
                    if (threshold >= 0) {
                        count += n;
                    }
                }
            }
            return count;
        };

        // Binary search for the k-th smallest product
        while (left < right) {
            long long mid = (left + right) >> 1;
            // If there are at least k products <= mid, the answer is in [left, mid]
            if (countProductsLessOrEqual(mid) >= k) {
                right = mid;
            } else {
                // Otherwise, the answer is in [mid + 1, right]
                left = mid + 1;
            }
        }

        return left;
    }
};
/**
 * Finds the k-th smallest product from all possible products of pairs (nums1[i], nums2[j])
 * @param nums1 - First sorted array of integers
 * @param nums2 - Second sorted array of integers
 * @param k - The k-th position to find (1-indexed)
 * @returns The k-th smallest product
 */
function kthSmallestProduct(nums1: number[], nums2: number[], k: number): number {
    const firstArrayLength = nums1.length;
    const secondArrayLength = nums2.length;

    // Calculate the maximum absolute values to determine search bounds
    const maxAbsFirstArray = BigInt(Math.max(Math.abs(nums1[0]), Math.abs(nums1[firstArrayLength - 1])));
    const maxAbsSecondArray = BigInt(Math.max(Math.abs(nums2[0]), Math.abs(nums2[secondArrayLength - 1])));

    // Set binary search bounds for the product value
    let leftBound = -maxAbsFirstArray * maxAbsSecondArray;
    let rightBound = maxAbsFirstArray * maxAbsSecondArray;

    /**
     * Counts how many products are less than or equal to the threshold
     * @param threshold - The product value threshold to count against
     * @returns Number of products <= threshold
     */
    const countProductsLessThanOrEqual = (threshold: bigint): bigint => {
        let totalCount = 0n;

        for (const num1 of nums1) {
            const bigNum1 = BigInt(num1);

            if (bigNum1 > 0n) {
                // For positive num1, find how many nums2[j] satisfy num1 * nums2[j] <= threshold
                let left = 0;
                let right = secondArrayLength;

                while (left < right) {
                    const mid = (left + right) >> 1;
                    const product = bigNum1 * BigInt(nums2[mid]);

                    if (product > threshold) {
                        right = mid;
                    } else {
                        left = mid + 1;
                    }
                }
                totalCount += BigInt(left);

            } else if (bigNum1 < 0n) {
                // For negative num1, find how many nums2[j] satisfy num1 * nums2[j] <= threshold
                // Note: multiplication by negative reverses inequality
                let left = 0;
                let right = secondArrayLength;

                while (left < right) {
                    const mid = (left + right) >> 1;
                    const product = bigNum1 * BigInt(nums2[mid]);

                    if (product <= threshold) {
                        right = mid;
                    } else {
                        left = mid + 1;
                    }
                }
                totalCount += BigInt(secondArrayLength - left);

            } else if (threshold >= 0n) {
                // If num1 is 0, all products are 0
                // Count all pairs if threshold is non-negative
                totalCount += BigInt(secondArrayLength);
            }
        }

        return totalCount;
    };

    // Binary search for the k-th smallest product
    while (leftBound < rightBound) {
        const midValue = (leftBound + rightBound) >> 1n;

        if (countProductsLessThanOrEqual(midValue) >= BigInt(k)) {
            rightBound = midValue;
        } else {
            leftBound = midValue + 1n;
        }
    }

    return Number(leftBound);
}

Time and Space Complexity

Time Complexity: O(m × log n × log M)

The time complexity breaks down as follows:

  • The outer binary search runs bisect_left over range(-mx, mx + 1), which holds the 2M + 1 integers in [-M, M]. It therefore performs O(log M) iterations, where M = mx = max(abs(nums1[0]), abs(nums1[-1])) × max(abs(nums2[0]), abs(nums2[-1])) is the maximum possible absolute value of a product.
  • For each binary search iteration, the count function is called, which iterates through all m elements in nums1, taking O(m) time.
  • Within each iteration of count, either bisect_right or bisect_left is called on nums2, which takes O(log n) time where n is the length of nums2.
  • Combining these operations: O(log M) × O(m) × O(log n) = O(m × log n × log M)

Space Complexity: O(1)

The space complexity analysis:

  • The count function uses only a constant amount of extra space for variables cnt, n, and x.
  • The bisect_left and bisect_right functions only read nums2 and use O(1) extra space.
  • The object range(-mx, mx + 1) is a lazy sequence. It supports the O(1) indexing that bisect_left needs without building a list of 2 × mx + 1 integers, so it uses O(1) space.
  • No additional data structures are created that scale with input size.


Common Pitfalls

1. Floating-Point Precision Issues

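A subtle pitfall in this solution is the use of floating-point division (p / x) to compute the threshold for each x. A double only approximates the exact quotient. If rounding ever moved that quotient across an integer boundary, bisect_right or bisect_left would return a position that is off by one, and the count would be wrong.

Problem Example:

# When p = 10^9 and x = 3
# The exact quotient is 333333333.333..., which a double can only approximate
# The comparison against the integers in nums2 then depends on how that rounding falls
# Near an integer boundary, bisect_right/bisect_left could return the wrong position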

Solution: Replace floating-point division with integer arithmetic:

def count_products_less_than_or_equal(target_product: int) -> int:
    total_count = 0
    nums2_length = len(nums2)
  
    for num1 in nums1:
        if num1 > 0:
            # Find how many nums2[j] satisfy num1 * nums2[j] <= target_product
            left, right = 0, nums2_length
            while left < right:
                mid = (left + right) // 2
                if num1 * nums2[mid] <= target_product:
                    left = mid + 1
                else:
                    right = mid
            total_count += left
        elif num1 < 0:
            # Find how many nums2[j] satisfy num1 * nums2[j] <= target_product
            left, right = 0, nums2_length
            while left < right:
                mid = (left + right) // 2
                if num1 * nums2[mid] <= target_product:
                    right = mid
                else:
                    left = mid + 1
            total_count += nums2_length - left
        else:
            total_count += nums2_length * int(target_product >= 0)
  
    return total_count
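Python's fractions module makes the approximation visible. Floor and ceiling division give the exact integer thresholds that bisect_right and bisect_left need. A small illustration:

```python
from fractions import Fraction

p, x = 10**9, 3
approx = p / x          # a binary double, close to but not equal to 10**9 / 3
exact = Fraction(p, x)  # the true rational quotient 1000000000/3

print(Fraction(approx) == exact)  # False: the double is only an approximation
print(p // x)                     # 333333333, exact floor(p / x)
print(-(-p // x))                 # 333333334, exact ceil(p / x)
```

For integer y, y <= p / x is equivalent to y <= p // x when x > 0. When x < 0, y >= p / x is equivalent to y >= -(-p // x). Either integer form can replace the float threshold in count.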

2. Integer Overflow in Product Calculation

When calculating max_absolute_value or products during comparison, the multiplication might overflow in languages with fixed integer sizes (though Python handles arbitrary precision integers).

Problem Example:

# If nums1[-1] = 10^5 and nums2[-1] = 10^5
# Their product is 10^10, which exceeds 32-bit integer range

Solution: For languages with overflow concerns, use long/int64 types or implement overflow-safe multiplication checks:

def safe_multiply(a: int, b: int, limit: int) -> int:
    """Returns min(a * b, limit) without overflow"""
    if a == 0 or b == 0:
        return 0
    if abs(a) > limit // abs(b):
        return limit if (a > 0) == (b > 0) else -limit
    return a * b
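Python integers never overflow, so the problem is easy to miss when prototyping in Python. The sketch below emulates what a 32-bit signed multiply stores. In Java, int arithmetic wraps modulo 2^32, which is why the Java solution casts to long before multiplying. The helper name wrap_int32 is hypothetical.

```python
def wrap_int32(value: int) -> int:
    """Emulate two's-complement wraparound of a 32-bit signed integer."""
    return (value + 2**31) % 2**32 - 2**31


print(wrap_int32(10**5 * 10**5))     # 1410065408, not 10000000000
print(wrap_int32(-(10**5) * 10**5))  # -1410065408
```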

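3. Misjudging the Search Bounds with Mixed Signs

With mixed signs, it can look as if the extreme products might involve interior elements, so that the endpoint-based max_absolute_value would be too small. That cannot happen. In a sorted array the largest absolute value is always at one of the two ends, so max_absolute_value bounds every product. The bound can, however, be looser than necessary.

Problem Example:

# nums1 = [-10, -1, 1, 10]
# nums2 = [-10, -1, 1, 10]
# Maximum product: 10 * 10 = 100 (an endpoint pair)
# Minimum product: -10 * 10 = -100 (an endpoint pair)
# Interior elements have smaller magnitudes, so they never produce a more extreme product
# By contrast, for nums1 = [1, 2] and nums2 = [3, 4] the symmetric range is [-8, 8],
# while every product lies in [3, 8]

Optional refinement: Compute the exact minimum and maximum products. For a fixed x, x * y is monotonic in y (and vice versa), so both extremes come from the four endpoint pairs: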

# Calculate all possible extreme products
candidates = [
    nums1[0] * nums2[0],
    nums1[0] * nums2[-1],
    nums1[-1] * nums2[0],
    nums1[-1] * nums2[-1]
]
min_product = min(candidates)
max_product = max(candidates)

# Then use range [min_product, max_product] for binary search
# bisect_left returns an index; adding min_product (the range's start) converts it to a value
return bisect_left(range(min_product, max_product + 1), k,
                   key=count_products_less_than_or_equal) + min_product
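The claim that the four endpoint pairs always contain the true minimum and maximum products can be checked with random sorted arrays. A quick sketch follows; the name tight_bounds is our own.

```python
import random
from typing import List, Tuple


def tight_bounds(nums1: List[int], nums2: List[int]) -> Tuple[int, int]:
    """Exact min and max product of two sorted arrays, from the four endpoint pairs."""
    candidates = [nums1[0] * nums2[0], nums1[0] * nums2[-1],
                  nums1[-1] * nums2[0], nums1[-1] * nums2[-1]]
    return min(candidates), max(candidates)


for _ in range(1000):
    a = sorted(random.randint(-50, 50) for _ in range(random.randint(1, 6)))
    b = sorted(random.randint(-50, 50) for _ in range(random.randint(1, 6)))
    all_products = [x * y for x in a for y in b]
    assert tight_bounds(a, b) == (min(all_products), max(all_products))
```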

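4. Over-Wide Binary Search Range

Using range(-mx, mx + 1) is not an off-by-one error: every product lies in [-mx, mx], so the answer is always inside the range. When all products share a sign (for example, when both arrays are non-negative), about half of that range can never contain the answer. The search still finds the correct value; it just spends extra iterations there. Because the search takes O(log P) steps, halving the range saves only about one iteration.

Solution: The symmetric range is safe to keep. Computing the exact minimum and maximum products is an optional optimization, not a correctness fix.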
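A randomized stress test confirms that both range choices return the same answer as brute force. This sketch swaps in the integer-threshold count (floor and ceiling division) in place of the float version. The function name and the tight flag are our own additions for illustration.

```python
import random
from bisect import bisect_left, bisect_right
from typing import List


def kth_smallest_product(nums1: List[int], nums2: List[int], k: int, tight: bool = False) -> int:
    """Binary search on the answer; tight=True uses exact endpoint bounds
    instead of the symmetric [-mx, mx] range."""
    n = len(nums2)

    def count(p: int) -> int:
        # Integer thresholds: floor(p / x) for x > 0, ceil(p / x) for x < 0
        cnt = 0
        for x in nums1:
            if x > 0:
                cnt += bisect_right(nums2, p // x)
            elif x < 0:
                cnt += n - bisect_left(nums2, -(-p // x))
            else:
                cnt += n if p >= 0 else 0
        return cnt

    if tight:
        ends = [nums1[0] * nums2[0], nums1[0] * nums2[-1],
                nums1[-1] * nums2[0], nums1[-1] * nums2[-1]]
        lo, hi = min(ends), max(ends)
    else:
        mx = max(abs(nums1[0]), abs(nums1[-1])) * max(abs(nums2[0]), abs(nums2[-1]))
        lo, hi = -mx, mx

    while lo < hi:
        mid = (lo + hi) // 2
        if count(mid) >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo


# Stress test: both range choices must agree with brute force
for _ in range(300):
    a = sorted(random.randint(-20, 20) for _ in range(random.randint(1, 5)))
    b = sorted(random.randint(-20, 20) for _ in range(random.randint(1, 5)))
    products = sorted(x * y for x in a for y in b)
    for k in range(1, len(products) + 1):
        assert kth_smallest_product(a, b, k) == products[k - 1]
        assert kth_smallest_product(a, b, k, tight=True) == products[k - 1]
```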

