2964. Number of Divisible Triplet Sums 🔒

Medium · Array · Hash Table
Problem Description

You are given a 0-indexed integer array nums and an integer d. Your task is to find the number of triplets (i, j, k) that satisfy the following conditions:

  1. The indices must follow the order: i < j < k
  2. The sum of the three elements at these indices must be divisible by d, meaning (nums[i] + nums[j] + nums[k]) % d == 0

The solution uses a hash table to track remainders and iterates through the array. For each position j, it:

  • Iterates through all positions k where k > j
  • Calculates what remainder value x would be needed from a previous element nums[i] (where i < j) to make the sum divisible by d
  • The required remainder is calculated as x = (d - (nums[j] + nums[k]) % d) % d
  • Adds the count of previously seen elements with remainder x to the answer
  • After checking all k values for a given j, records nums[j] % d in the hash table for future iterations

This approach ensures that we only count valid triplets where i < j < k by maintaining the hash table of remainders seen so far (elements before j) and checking elements after j (at position k).

Intuition

The key insight is recognizing that for three numbers to sum to a value divisible by d, we need (nums[i] + nums[j] + nums[k]) % d == 0. This can be rewritten as: we need to find nums[i] such that nums[i] % d equals a specific value that complements (nums[j] + nums[k]) % d to make the total divisible by d.

Instead of checking all possible triplets with three nested loops (which would be O(n³)), we can be smarter about it. The trick is to fix the middle element j and think about what happens:

  • For elements to the right of j (potential k values), we can iterate through them
  • For elements to the left of j (potential i values), we don't want to iterate through them again for each (j, k) pair

This leads us to the idea of preprocessing: as we move through the array with j, we can maintain information about all elements we've seen so far (which are valid i candidates). Specifically, we only care about the remainder when each element is divided by d, since that's what determines divisibility.

For a given pair (j, k), we know their sum modulo d is (nums[j] + nums[k]) % d. To make the total sum divisible by d, we need the third element nums[i] to have a remainder that "cancels out" this value. If (nums[j] + nums[k]) % d = r, then we need nums[i] % d = (d - r) % d.
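The complement rule can be checked mechanically. The snippet below is a small sketch, not part of the original solution; the helper name `required_remainder` is chosen here for illustration. It confirms that adding the complement to any pair remainder always yields a multiple of d.

```python
def required_remainder(pair_sum: int, d: int) -> int:
    """Remainder that nums[i] must have so that nums[i] + pair_sum is divisible by d."""
    r = pair_sum % d
    # The outer modulo maps the r == 0 case to 0 instead of d.
    return (d - r) % d


d = 5
for r in range(d):
    x = required_remainder(r, d)
    # Adding the complement always lands on a multiple of d.
    assert (r + x) % d == 0
    print(f"pair remainder {r} -> required remainder {x}")
```

For d = 5 the pairs are 0 with 0, 1 with 4, 2 with 3, 3 with 2, and 4 with 1.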

By maintaining a hash table that counts how many elements we've seen with each possible remainder value, we can instantly know how many valid i values exist for any (j, k) pair. This reduces our time complexity from O(n³) to O(n²).
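For comparison, the O(n³) baseline mentioned above can be written in a few lines. This is a reference sketch (the function name `brute_force_triplets` is an assumption made here), mainly useful for checking a faster solution on small inputs.

```python
from itertools import combinations
from typing import List


def brute_force_triplets(nums: List[int], d: int) -> int:
    """Count triplets i < j < k whose sum is divisible by d by checking every triplet."""
    # combinations() yields elements in index order, so i < j < k holds implicitly.
    return sum(1 for a, b, c in combinations(nums, 3) if (a + b + c) % d == 0)


print(brute_force_triplets([3, 3, 4, 7, 8], 5))
```

Because it examines every triplet, it is only practical for short arrays, but it makes a convenient oracle when testing the hash table approach.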

Solution Approach

The solution uses a hash table combined with enumeration to efficiently count valid triplets.

Data Structure:

  • cnt: A hash table (defaultdict) that stores the count of remainders. Specifically, cnt[r] represents how many elements we've processed so far that have remainder r when divided by d.

Algorithm Steps:

  1. Initialize an empty hash table cnt and set ans = 0 to track the total count of valid triplets.

  2. Iterate through the array with index j from 0 to n-1:

    • For each j, this element will serve as the middle element of our triplet

  3. For each fixed j, iterate through all elements after it with index k from j+1 to n-1:

    • Calculate the sum (nums[j] + nums[k]) % d
    • Determine what remainder we need from nums[i] to make the total sum divisible by d:
      • x = (d - (nums[j] + nums[k]) % d) % d
      • The outer modulo operation handles the case when (nums[j] + nums[k]) % d = 0, ensuring x = 0 instead of x = d
    • Add cnt[x] to the answer, which represents how many valid i indices exist for this (j, k) pair

  4. After processing all k values for the current j, update the hash table:

    • Add nums[j] % d to the hash table by incrementing cnt[nums[j] % d] += 1
    • This makes nums[j] available as a potential i value for future iterations

Why this works:

  • When we're at position j, all elements before j (indices 0 to j-1) have already been recorded in the hash table
  • These are exactly the valid candidates for index i since we need i < j
  • By checking all k > j and looking up the required remainder in our hash table, we count all valid triplets with j as the middle element
  • The order i < j < k is naturally maintained by this approach

Time Complexity: O(n²), from the two nested loops over j and k.

Space Complexity: O(min(n, d)), since the hash table holds one entry per distinct remainder and the array elements produce at most min(n, d) of those.

Example Walkthrough

Let's walk through the solution with a concrete example:

  • nums = [3, 3, 4, 7, 8]
  • d = 5

We need to find triplets (i, j, k) where i < j < k and (nums[i] + nums[j] + nums[k]) % 5 == 0.

Initial State:

  • cnt = {} (empty hash table)
  • ans = 0

Iteration 1: j = 0 (nums[j] = 3)

  • Check all k > 0:
    • k = 1: nums[k] = 3
      • Sum needed: (3 + 3) % 5 = 6 % 5 = 1
      • Required remainder: x = (5 - 1) % 5 = 4
      • cnt[4] = 0 (no elements with remainder 4 seen yet)
      • ans += 0
    • k = 2: nums[k] = 4
      • Sum needed: (3 + 4) % 5 = 7 % 5 = 2
      • Required remainder: x = (5 - 2) % 5 = 3
      • cnt[3] = 0 (no elements with remainder 3 seen yet)
      • ans += 0
    • k = 3: nums[k] = 7
      • Sum needed: (3 + 7) % 5 = 10 % 5 = 0
      • Required remainder: x = (5 - 0) % 5 = 0
      • cnt[0] = 0 (no elements with remainder 0 seen yet)
      • ans += 0
    • k = 4: nums[k] = 8
      • Sum needed: (3 + 8) % 5 = 11 % 5 = 1
      • Required remainder: x = (5 - 1) % 5 = 4
      • cnt[4] = 0
      • ans += 0
  • Update hash table: 3 % 5 = 3, so cnt[3] = 1
  • Current state: cnt = {3: 1}, ans = 0

Iteration 2: j = 1 (nums[j] = 3)

  • Check all k > 1:
    • k = 2: nums[k] = 4
      • Sum needed: (3 + 4) % 5 = 2
      • Required remainder: x = (5 - 2) % 5 = 3
      • cnt[3] = 1 (we have one element with remainder 3: nums[0])
      • ans += 1 (found triplet: i=0, j=1, k=2)
    • k = 3: nums[k] = 7
      • Sum needed: (3 + 7) % 5 = 0
      • Required remainder: x = (5 - 0) % 5 = 0
      • cnt[0] = 0
      • ans += 0
    • k = 4: nums[k] = 8
      • Sum needed: (3 + 8) % 5 = 1
      • Required remainder: x = (5 - 1) % 5 = 4
      • cnt[4] = 0
      • ans += 0
  • Update hash table: 3 % 5 = 3, so cnt[3] = 2
  • Current state: cnt = {3: 2}, ans = 1

Iteration 3: j = 2 (nums[j] = 4)

  • Check all k > 2:
    • k = 3: nums[k] = 7
      • Sum needed: (4 + 7) % 5 = 11 % 5 = 1
      • Required remainder: x = (5 - 1) % 5 = 4
      • cnt[4] = 0
      • ans += 0
    • k = 4: nums[k] = 8
      • Sum needed: (4 + 8) % 5 = 12 % 5 = 2
      • Required remainder: x = (5 - 2) % 5 = 3
      • cnt[3] = 2 (we have two elements with remainder 3: nums[0] and nums[1])
      • ans += 2 (found triplets: i=0, j=2, k=4 and i=1, j=2, k=4)
  • Update hash table: 4 % 5 = 4, so cnt[4] = 1
  • Current state: cnt = {3: 2, 4: 1}, ans = 3

Iteration 4: j = 3 (nums[j] = 7)

  • Check all k > 3:
    • k = 4: nums[k] = 8
      • Sum needed: (7 + 8) % 5 = 15 % 5 = 0
      • Required remainder: x = (5 - 0) % 5 = 0
      • cnt[0] = 0
      • ans += 0
  • Update hash table: 7 % 5 = 2, so cnt[2] = 1
  • Current state: cnt = {3: 2, 4: 1, 2: 1}, ans = 3

Iteration 5: j = 4 (nums[j] = 8)

  • No k > 4 exists, so no pairs to check
  • Update hash table: 8 % 5 = 3, so cnt[3] = 3

Final Answer: 3

The three valid triplets are:

  1. (0, 1, 2): nums[0] + nums[1] + nums[2] = 3 + 3 + 4 = 10, which is divisible by 5
  2. (0, 2, 4): nums[0] + nums[2] + nums[4] = 3 + 4 + 8 = 15, which is divisible by 5
  3. (1, 2, 4): nums[1] + nums[2] + nums[4] = 3 + 4 + 8 = 15, which is divisible by 5
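To reproduce the walkthrough programmatically, the same scan can record indices instead of counts. The following is an illustrative variant, not the counting solution itself: the name `list_divisible_triplets` is hypothetical, and because it materializes every triplet it can take more than O(n²) time when many triplets exist.

```python
from collections import defaultdict
from typing import List, Tuple


def list_divisible_triplets(nums: List[int], d: int) -> List[Tuple[int, int, int]]:
    """Return every (i, j, k) with i < j < k and a sum divisible by d."""
    # Same scan as the counting solution, but each remainder maps to the
    # indices seen so far instead of just how many there were.
    indices_by_remainder = defaultdict(list)
    triplets: List[Tuple[int, int, int]] = []
    for j in range(len(nums)):
        for k in range(j + 1, len(nums)):
            x = (d - (nums[j] + nums[k]) % d) % d
            # .get() avoids inserting empty lists for remainders never seen.
            for i in indices_by_remainder.get(x, []):
                triplets.append((i, j, k))
        # Record j only after all k have been checked, so i < j always holds.
        indices_by_remainder[nums[j] % d].append(j)
    return triplets


print(list_divisible_triplets([3, 3, 4, 7, 8], 5))
# [(0, 1, 2), (0, 2, 4), (1, 2, 4)]
```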

Solution Implementation

Python:

from typing import List
from collections import defaultdict

class Solution:
    def divisibleTripletCount(self, nums: List[int], d: int) -> int:
        # Dictionary to count occurrences of each remainder when divided by d
        # Key: remainder value, Value: count of elements with that remainder seen so far
        remainder_count = defaultdict(int)

        # Initialize result counter and get array length
        result = 0
        n = len(nums)

        # Iterate through middle element j of potential triplets (i, j, k)
        for j in range(n):
            # For each k > j, check how many valid i < j exist
            for k in range(j + 1, n):
                # Calculate the required remainder for nums[i] to make the triplet sum divisible by d
                # If (nums[i] + nums[j] + nums[k]) % d == 0, then nums[i] % d must equal required_remainder
                required_remainder = (d - (nums[j] + nums[k]) % d) % d

                # Add count of all valid i values (those before j with the required remainder)
                result += remainder_count[required_remainder]

            # After processing all k values for current j, add nums[j] to our remainder count
            # This ensures it can be used as a potential i value for future j positions
            remainder_count[nums[j] % d] += 1

        return result

Java:

import java.util.HashMap;
import java.util.Map;

class Solution {
    public int divisibleTripletCount(int[] nums, int d) {
        // Map to store frequency of remainders (when divided by d) for elements before index j
        Map<Integer, Integer> remainderCount = new HashMap<>();
        int totalTriplets = 0;
        int arrayLength = nums.length;

        // Iterate through each possible middle element of the triplet (index j)
        for (int j = 0; j < arrayLength; ++j) {
            // For current j, check all possible third elements (index k where k > j)
            for (int k = j + 1; k < arrayLength; ++k) {
                // Calculate the remainder needed from nums[i] to make the sum divisible by d
                // If (nums[i] + nums[j] + nums[k]) % d == 0, then
                // nums[i] % d == (d - (nums[j] + nums[k]) % d) % d
                int requiredRemainder = (d - (nums[j] + nums[k]) % d) % d;

                // Add count of all elements before j that have the required remainder
                totalTriplets += remainderCount.getOrDefault(requiredRemainder, 0);
            }

            // After processing all k for current j, add nums[j] to the remainder count map
            // This ensures nums[j] can be used as nums[i] for future iterations
            remainderCount.merge(nums[j] % d, 1, Integer::sum);
        }

        return totalTriplets;
    }
}

C++:

#include <unordered_map>
#include <vector>

using std::unordered_map;
using std::vector;

class Solution {
public:
    int divisibleTripletCount(vector<int>& nums, int d) {
        // Map to store frequency of remainders when divided by d
        // Key: remainder value, Value: count of elements with that remainder
        unordered_map<int, int> remainderCount;

        int result = 0;
        int n = nums.size();

        // Iterate through array considering each element as middle element (j)
        for (int j = 0; j < n; ++j) {
            // For current j, check all possible k values (k > j)
            for (int k = j + 1; k < n; ++k) {
                // Calculate required remainder for nums[i] to make triplet divisible by d
                // We need: (nums[i] + nums[j] + nums[k]) % d == 0
                // So: nums[i] % d == (d - (nums[j] + nums[k]) % d) % d
                int requiredRemainder = (d - (nums[j] + nums[k]) % d) % d;

                // Add count of all previous elements with the required remainder
                // These will form valid triplets (i, j, k) where i < j < k
                result += remainderCount[requiredRemainder];
            }

            // After processing all k values for current j,
            // add nums[j] to the map for future iterations
            remainderCount[nums[j] % d]++;
        }

        return result;
    }
};

TypeScript:

/**
 * Counts the number of triplets (i, j, k) where i < j < k
 * such that (nums[i] + nums[j] + nums[k]) is divisible by d
 *
 * @param nums - Array of integers
 * @param d - The divisor to check divisibility against
 * @returns Number of valid triplets
 */
function divisibleTripletCount(nums: number[], d: number): number {
    const arrayLength: number = nums.length;
    // Map to store frequency of remainders when divided by d
    const remainderFrequency: Map<number, number> = new Map<number, number>();
    let tripletCount: number = 0;

    // Iterate through array considering element at index j as middle element
    for (let j = 0; j < arrayLength; ++j) {
        // For each j, check all possible k values where k > j
        for (let k = j + 1; k < arrayLength; ++k) {
            // Calculate required remainder for nums[i] to make the sum divisible by d
            // If (nums[i] + nums[j] + nums[k]) % d = 0, then
            // nums[i] % d must equal (d - (nums[j] + nums[k]) % d) % d
            const requiredRemainder: number = (d - ((nums[j] + nums[k]) % d)) % d;

            // Add count of all previous elements with the required remainder
            tripletCount += remainderFrequency.get(requiredRemainder) || 0;
        }

        // After processing all k values for current j,
        // add nums[j] to the map for future iterations
        const currentRemainder: number = nums[j] % d;
        remainderFrequency.set(currentRemainder, (remainderFrequency.get(currentRemainder) || 0) + 1);
    }

    return tripletCount;
}


Time and Space Complexity

Time Complexity: O(n²)

The code contains two nested loops:

  • The outer loop iterates through j from 0 to n-1, running n times
  • The inner loop iterates through k from j+1 to n-1, running approximately n-j-1 times for each j
  • The total iterations across both loops is (n-1) + (n-2) + ... + 1 + 0 = n(n-1)/2, which simplifies to O(n²)
  • Inside the inner loop, all operations (modulo calculation, dictionary lookup, and addition) are O(1)
  • After the inner loop, the dictionary update operation is also O(1)

Therefore, the overall time complexity is O(n²).
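The n(n-1)/2 figure can be confirmed by counting the inner-loop iterations directly. This is a small check written for illustration; `inner_iterations` is a name introduced here.

```python
def inner_iterations(n: int) -> int:
    """Count how many (j, k) pairs the nested loops visit for an array of length n."""
    count = 0
    for j in range(n):
        for _k in range(j + 1, n):
            count += 1
    return count


for n in (0, 1, 5, 100):
    # Matches the closed form (n-1) + (n-2) + ... + 1 + 0 = n(n-1)/2.
    assert inner_iterations(n) == n * (n - 1) // 2
```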

Space Complexity: O(min(n, d))

The space usage comes from:

  • The cnt dictionary, which stores the count of remainders when elements are divided by d
  • Each remainder lies in the range [0, d-1] and each element contributes one remainder, so the recorded counts occupy at most min(n, d) distinct keys
  • When d ≥ n, every element can have a different remainder and the dictionary stores up to n entries; when d < n, it stores at most d
  • Python's defaultdict and C++'s unordered_map operator[] also insert a zero entry when a missing key is looked up, so those versions can hold more keys (still at most d); a non-inserting lookup such as get() or find() keeps the min(n, d) bound
  • Other variables (ans, n, j, k, x) use constant space O(1)

Therefore, the overall space complexity is O(min(n, d)).
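The min(n, d) bound on distinct keys is easy to observe. The sketch below only records remainders (it performs no lookups), and the helper name `distinct_remainders` is an assumption made for this example.

```python
from collections import defaultdict
from typing import List


def distinct_remainders(nums: List[int], d: int) -> int:
    """Number of keys the remainder dictionary ends up with after recording every element."""
    remainder_count = defaultdict(int)
    for value in nums:
        remainder_count[value % d] += 1
    return len(remainder_count)


nums = list(range(1, 101))               # n = 100 distinct values
print(distinct_remainders(nums, 7))      # 7: d < n, bounded by d
print(distinct_remainders(nums, 1000))   # 100: d > n, bounded by n
```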

Common Pitfalls

1. Incorrect Remainder Calculation

One of the most common mistakes is incorrectly calculating the required remainder for nums[i].

Pitfall Code:

# WRONG: This fails when (nums[j] + nums[k]) % d == 0
required_remainder = d - (nums[j] + nums[k]) % d

Issue: When (nums[j] + nums[k]) % d == 0, this formula gives required_remainder = d, but we actually need required_remainder = 0 since we're looking for elements with remainder 0.

Solution:

# CORRECT: The outer modulo ensures we get 0 instead of d
required_remainder = (d - (nums[j] + nums[k]) % d) % d
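A minimal input makes the difference visible. In this sketch the flag `outer_mod` (a parameter invented for the demonstration) switches between the two formulas; with nums = [5, 5, 5] and d = 5, the only triplet is missed when the outer modulo is dropped, because the lookup asks for remainder 5, which never occurs.

```python
from collections import defaultdict
from typing import List


def count_triplets(nums: List[int], d: int, outer_mod: bool) -> int:
    """Count divisible triplets; outer_mod=False reproduces the pitfall formula."""
    remainder_count = defaultdict(int)
    result = 0
    for j in range(len(nums)):
        for k in range(j + 1, len(nums)):
            required = d - (nums[j] + nums[k]) % d
            if outer_mod:
                required %= d  # maps d back to 0
            result += remainder_count[required]
        remainder_count[nums[j] % d] += 1
    return result


print(count_triplets([5, 5, 5], 5, outer_mod=True))   # 1
print(count_triplets([5, 5, 5], 5, outer_mod=False))  # 0
```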

2. Updating Hash Table at Wrong Time

Another critical mistake is updating the hash table at the wrong point in the iteration.

Pitfall Code:

for j in range(n):
    # WRONG: Adding nums[j] to hash table BEFORE checking k values
    remainder_count[nums[j] % d] += 1
  
    for k in range(j + 1, n):
        required_remainder = (d - (nums[j] + nums[k]) % d) % d
        result += remainder_count[required_remainder]

Issue: This would count invalid triplets where i == j, violating the constraint i < j < k.

Solution:

for j in range(n):
    for k in range(j + 1, n):
        required_remainder = (d - (nums[j] + nums[k]) % d) % d
        result += remainder_count[required_remainder]
  
    # CORRECT: Add nums[j] AFTER processing all k values
    remainder_count[nums[j] % d] += 1
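The overcount shows up even on a two-element array, which cannot contain any triplet at all. The sketch below uses a hypothetical `update_first` flag to toggle the ordering: with nums = [1, 3] and d = 5, the early update lets nums[0] act as its own i, since 1 + 1 + 3 = 5.

```python
from collections import defaultdict
from typing import List


def count_with_order(nums: List[int], d: int, update_first: bool) -> int:
    """Count triplets; update_first=True reproduces the wrong update order."""
    remainder_count = defaultdict(int)
    result = 0
    for j in range(len(nums)):
        if update_first:
            # Pitfall: nums[j] becomes visible as a candidate for i while it is still j.
            remainder_count[nums[j] % d] += 1
        for k in range(j + 1, len(nums)):
            result += remainder_count[(d - (nums[j] + nums[k]) % d) % d]
        if not update_first:
            remainder_count[nums[j] % d] += 1
    return result


print(count_with_order([1, 3], 5, update_first=False))  # 0
print(count_with_order([1, 3], 5, update_first=True))   # 1
```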

3. Using Regular Dictionary Instead of DefaultDict

Using a regular dictionary without proper initialization can cause KeyError.

Pitfall Code:

remainder_count = {}  # Regular dictionary
# ...
result += remainder_count[required_remainder]  # KeyError if key doesn't exist

Solution:

# Option 1: Use defaultdict
from collections import defaultdict
remainder_count = defaultdict(int)

# Option 2: Use get() method with default value
remainder_count = {}
result += remainder_count.get(required_remainder, 0)
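A third possibility, offered here as a suggestion rather than something the original solution uses, is collections.Counter. It also returns 0 for a missing key, but unlike defaultdict it does not store a new entry when a missing key is merely read.

```python
from collections import Counter, defaultdict

counter = Counter()
default = defaultdict(int)

# Both return 0 for a remainder that has not been recorded yet ...
assert counter[4] == 0
assert default[4] == 0

# ... but only defaultdict stores a new entry as a side effect of the lookup.
print(len(counter))  # 0
print(len(default))  # 1
```

Either container produces the same answer; the difference only affects how many keys end up stored.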

4. Handling Negative Numbers Incorrectly

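If the array contains negative numbers, Python's modulo operator still works: for a positive d, x % d always lies in [0, d-1]. In Java, C++, and TypeScript, % takes the sign of the dividend, so the result can be negative and must be normalized, for example with ((x % d) + d) % d.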

Example: For d = 5 and nums[j] + nums[k] = -3:

  • Python's (-3) % 5 returns 2 (not -3)
  • The required remainder calculation still works correctly

Best Practice: Always test with negative numbers to ensure your solution handles them properly.
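The contrast between the two remainder conventions can be illustrated in Python. The helpers below are assumptions made for this sketch: `truncated_remainder` imitates a remainder that takes the sign of the dividend (for d > 0), and `normalized_remainder` shows the usual adjustment back into [0, d-1].

```python
def truncated_remainder(value: int, d: int) -> int:
    """Remainder carrying the sign of the dividend (assumes d > 0)."""
    magnitude = abs(value) % d
    return magnitude if value >= 0 else -magnitude


def normalized_remainder(value: int, d: int) -> int:
    """Shift a possibly negative truncated remainder into the range [0, d-1]."""
    return (truncated_remainder(value, d) + d) % d


d = 5
pair_sum = -3
print(pair_sum % d)                       # 2: Python's floor-based result
print(truncated_remainder(pair_sum, d))   # -3: sign follows the dividend
print(normalized_remainder(pair_sum, d))  # 2: normalized back into [0, d-1]
```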
