
3539. Find Sum of Array Product of Magical Sequences

Hard. Topics: Bit Manipulation, Array, Math, Dynamic Programming, Bitmask, Combinatorics (LeetCode)

Problem Description

You are given two integers, m and k, along with an integer array nums.

A sequence of integers seq is called magical if it satisfies all of the following conditions:

  • The sequence seq has a size of m (it contains exactly m elements).
  • Every element is a valid index into nums, meaning 0 <= seq[i] < nums.length.
  • When you compute the value 2^seq[0] + 2^seq[1] + ... + 2^seq[m - 1], its binary representation must contain exactly k set bits (bits equal to 1).

For any sequence, its array product is defined as the product of the nums values located at each index in the sequence:

prod(seq) = nums[seq[0]] * nums[seq[1]] * ... * nums[seq[m - 1]]

Your task is to find the sum of the array products over all valid magical sequences.

Because the result can be very large, return the answer modulo 10^9 + 7.

For clarity, a set bit is a bit in the binary representation of a number whose value is 1.
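A direct brute force helps pin the definition down and gives a reference to test faster solutions against. The sketch below enumerates every ordered sequence of m indices with itertools.product. It checks the popcount of the sum of powers of two and adds up the products. The function name magical_sum_bruteforce is hypothetical. It runs in O(n^m) time, so it only suits tiny inputs.

```python
from itertools import product

MOD = 10**9 + 7


def magical_sum_bruteforce(m: int, k: int, nums: list[int]) -> int:
    """Sum of array products over all magical sequences, by direct enumeration.

    Exponential in m (len(nums) ** m sequences): only for tiny inputs.
    """
    n = len(nums)
    total = 0
    # Every ordered sequence of m indices; repeated indices are allowed.
    for seq in product(range(n), repeat=m):
        value = sum(1 << i for i in seq)          # 2^seq[0] + ... + 2^seq[m-1]
        if bin(value).count("1") == k:            # exactly k set bits
            prod_val = 1
            for i in seq:
                prod_val = prod_val * nums[i] % MOD
            total = (total + prod_val) % MOD
    return total


print(magical_sum_bruteforce(2, 1, [2, 3]))  # 13
```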

Key things to notice:

  • The sequence has a fixed length m, and each position independently chooses an index from 0 to nums.length - 1. Different positions may choose the same index, so repeated indices are allowed.
  • Because seq is an ordered sequence, two sequences that contain the same multiset of indices in different orders count as distinct sequences. This is why combinatorial counting (choosing which positions take a given index value) is involved.
  • The condition involving k is about the number of set bits after summing the powers of two 2^seq[i]. Because choosing the same index multiple times adds multiple copies of the same power of two, carries can occur when these powers add up, which affects the final count of set bits.

How We Pick the Algorithm

Why Dynamic Programming?

This problem maps to Dynamic Programming through a short path in the full flowchart.

Small constraints? Yes → DP or memoization needed? Yes → Dynamic Programming

The problem uses DP with memoized state (index, remaining positions, popcount budget, carry) to count weighted assignments, with m up to 30 keeping the carry dimension bounded.


Intuition

The first observation is that a magical sequence is built by assigning each of the m positions an index from nums. Instead of thinking about the order of positions directly, we can think about how many times each index is chosen. Suppose index i is chosen t times. Then those t positions contribute nums[i]^t to the array product, and the number of ways to pick exactly which t positions (out of the remaining positions) get value i is the binomial coefficient C(j, t), where j is the number of positions still unassigned. This naturally turns the problem into processing the indices 0, 1, ..., n-1 one at a time and deciding how many positions go to each index.

The trickier part is the condition on set bits. Each time we pick index i, we add 2^i to a running total. If index i is picked t times, that contributes t * 2^i. The key insight is that we don't need to track the entire huge sum — we only need to know how many set bits its binary representation has. We can process bit positions from low to high, just like adding numbers by hand and handling carries.

Think of it this way: when we are at index i (which corresponds to bit position i), the total contribution to that bit is t (the number of times we chose index i) plus whatever carry came in from the lower bits. Call this combined amount nt = t + st, where st is the incoming carry. The lowest bit of nt, that is nt & 1, determines whether bit i is set in the final number. If it is set, we have used up one of our k required set bits. The rest of nt, namely nt >> 1, becomes the carry that flows into the next higher bit position.
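The bit-by-bit rule can be checked in isolation. The sketch below takes a hypothetical list counts, where counts[i] is how many times index i was chosen. It computes the popcount of sum(counts[i] * 2^i) using only nt & 1 and nt >> 1 at each position, then flushes the leftover carry. The result matches the popcount of the full sum.

```python
def popcount_with_carries(counts: list[int]) -> int:
    """Set bits of sum(counts[i] * 2**i), computed bit by bit with a carry."""
    bits = 0
    carry = 0
    for t in counts:              # counts[i] = times index i was chosen
        nt = t + carry            # total weight landing on bit i
        bits += nt & 1            # the low bit decides whether bit i is set
        carry = nt >> 1           # the rest moves up to bit i + 1
    while carry:                  # set bits above the last processed position
        bits += carry & 1
        carry >>= 1
    return bits


counts = [2, 1]                   # index 0 twice, index 1 once: 1 + 1 + 2 = 4
print(popcount_with_carries(counts))                               # 1
print(bin(sum(t << i for i, t in enumerate(counts))).count("1"))  # 1
```

The function never builds the full sum. This is the property the DP relies on: only the current carry has to be remembered between bit positions.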

This leads directly to a recursive design dfs(i, j, k, st):

  • i is the current index / bit position we are processing.
  • j is how many positions are still left to assign.
  • k is how many set bits we still need to produce.
  • st is the carry coming into bit position i from the lower bits.

At each step, we try every possible count t from 0 to j for the current index. For each choice we multiply by C(j, t) (ways to place those positions) and by nums[i]^t (their contribution to the product), then recurse with the updated carry (t + st) >> 1 and updated set-bit budget k - ((t + st) & 1).

When we finish processing all indices (i == n), the leftover carry st may still contain set bits at higher positions. We count those remaining set bits and subtract them from k. A solution is valid only when all positions are used (j == 0) and the set-bit count lands exactly at k == 0.

Finally, since binomial coefficients are needed repeatedly, we precompute factorials f and their modular inverses g, so that C(a, b) = f[a] * g[b] * g[a - b] % mod can be evaluated in constant time. Combining the carry-aware bit counting with combinatorial position assignment gives an efficient memoized search.
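A minimal sketch of that precomputation follows. It assumes an upper bound MX = 30 on m, matching the stated constraint, and it relies on 10^9 + 7 being prime so that Fermat's little theorem gives the inverses. The helper C is named after the notation above. It is checked against Python's exact math.comb.

```python
from math import comb as math_comb

MOD = 10**9 + 7
MX = 30  # assumed upper bound on m

f = [1] * (MX + 1)   # f[i] = i! mod MOD
g = [1] * (MX + 1)   # g[i] = (i!)^(-1) mod MOD
for i in range(1, MX + 1):
    f[i] = f[i - 1] * i % MOD
    g[i] = pow(f[i], MOD - 2, MOD)   # Fermat's little theorem: MOD is prime


def C(a: int, b: int) -> int:
    """Binomial coefficient C(a, b) mod MOD, O(1) after precomputation."""
    return f[a] * g[b] % MOD * g[a - b] % MOD


print(C(5, 2))  # 10
```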


Solution Approach

We use combinatorics combined with memoized search.

We design a function dfs(i, j, k, st) that returns the sum of array products over all ways to finish the sequence from the following state: we are processing the i-th element of nums, j positions of the magical sequence are still unassigned, k more set bits are still required, and st is the carry coming in from the lower bits. The answer is then dfs(0, m, k, 0).

The execution process of function dfs(i, j, k, st) is as follows:

  • If k < 0, or if i == len(nums) and j > 0, it means the current solution is not feasible (we ran out of set-bit budget, or we reached the end without placing all positions), so we return 0.

  • If i == len(nums), it means we have finished processing array nums. We still need to account for any remaining set bits hidden in the carry st. We repeatedly examine the lowest bit of st, subtracting each set bit from k:

    while st:
        k -= st & 1
        st >>= 1

    If k == 0 after this, the solution is feasible and we return 1, otherwise we return 0.

  • Otherwise, we enumerate the count t of positions assigned to index i, where 0 <= t <= j:

    • The number of ways to choose which t of the j remaining positions take value i is the binomial coefficient C(j, t).
    • The product contribution from these t positions is nums[i]^t, computed as pow(nums[i], t, mod).
    • The combined amount at bit position i is nt = t + st. Its lowest bit nt & 1 tells whether bit i is set, so the updated set-bit budget becomes k - (nt & 1).
    • The remainder nt >> 1 becomes the carry flowing into the next bit position.
    • Let p = pow(nums[i], t, mod) and nk = k - (nt & 1). We recurse into dfs(i + 1, j - t, nk, nt >> 1) and accumulate:
      res += comb(j, t) * p * dfs(i + 1, j - t, nk, nt >> 1)
      res %= mod
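The base-case loop in the second bullet subtracts one from k for every set bit of the leftover carry. It is therefore the same as asking whether st has exactly k set bits. The sketch below makes that equivalence explicit with two hypothetical helpers, flush_ok (the loop as written) and flush_ok_popcount (the one-line popcount form).

```python
def flush_ok(k: int, st: int) -> bool:
    """Base case at i == n: do the leftover carry's set bits use up exactly k?"""
    while st:
        k -= st & 1     # each set bit of the carry is one more set bit in the sum
        st >>= 1
    return k == 0


def flush_ok_popcount(k: int, st: int) -> bool:
    """Equivalent check: the leftover carry must have exactly k set bits."""
    return bin(st).count("1") == k


print(flush_ok(1, 1), flush_ok(1, 3), flush_ok(2, 3))  # True False True
```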

To compute binomial coefficients efficiently, we preprocess the factorial array f and the inverse factorial array g, where f[i] = i! mod (10^9 + 7) and g[i] = (i!)^{-1} mod (10^9 + 7). The modular inverse of each factorial is obtained via Fermat's little theorem using pow(f[i], mod - 2, mod). Then each combination is evaluated in constant time:

comb(a, b) = f[a] * g[b] * g[a - b] % mod

The search is decorated with @cache so that repeated states (i, j, k, st) are computed only once. After obtaining the answer with dfs(0, m, k, 0), we call dfs.cache_clear() to release the memoization table.

Complexity analysis: Let n = len(nums). A state (i, j, k, st) has i in [0, n], j in [0, m], a set-bit budget in [0, k], and a carry st that never exceeds m / 2. At most m powers of two land on a given bit, and the carry is halved at every step. That gives O(n × m² × k) states, each looping over up to m + 1 choices of t. The time complexity is therefore O(n × m³ × k), and the space complexity, dominated by the memo table, is O(n × m² × k).
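The carry bound can be checked empirically. The hypothetical helper below explores every reachable (remaining, carry) pair forward through n bit positions and records the largest carry it sees. The carry into bit i + 1 equals the sum of t_j · 2^j over j ≤ i, divided by 2^(i + 1) and rounded down. That sum is at most m · 2^i, so the carry is at most m // 2, and the sketch reports exactly that value.

```python
def max_reachable_carry(m: int, n: int) -> int:
    """Largest carry st reachable in dfs, exploring (remaining, carry) pairs."""
    states = {(m, 0)}            # (j, st) before processing index 0
    best = 0
    for _ in range(n):           # one step per index / bit position
        nxt = set()
        for j, st in states:
            for t in range(j + 1):
                new_st = (t + st) >> 1
                nxt.add((j - t, new_st))
                best = max(best, new_st)
        states = nxt
    return best


print(max_reachable_carry(30, 10))  # 15
```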

Example Walkthrough

Let's use a small concrete example to trace through the solution approach.

Input: m = 2, k = 1, nums = [2, 3]

So n = len(nums) = 2. We want all sequences seq of length 2, where each element is an index into nums (so each is 0 or 1), such that 2^seq[0] + 2^seq[1] has exactly 1 set bit.

Step 1: Enumerate the magical sequences by hand (to verify)

All sequences of length 2 over indices {0, 1}:

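seq     | 2^seq[0] + 2^seq[1] | binary | set bits | magical? | product nums[seq[0]] * nums[seq[1]]
[0, 0]  | 1 + 1 = 2           | 10     | 1        | yes      | 2 * 2 = 4
[0, 1]  | 1 + 2 = 3           | 11     | 2        | no       | (not counted)
[1, 0]  | 2 + 1 = 3           | 11     | 2        | no       | (not counted)
[1, 1]  | 2 + 2 = 4           | 100    | 1        | yes      | 3 * 3 = 9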

The magical sequences are [0, 0] and [1, 1]. Notice that choosing the same index twice causes a carry: 1 + 1 = 2 = 10b, collapsing two powers into one higher set bit.

Expected answer: 4 + 9 = 13.

Step 2: Trace dfs(0, m, k, 0) = dfs(0, 2, 1, 0)

We process index 0 (bit position 0). Enumerate t = number of positions assigned to index 0, where 0 <= t <= j = 2. Carry in st = 0.

Branch t = 0 (index 0 chosen 0 times):

  • C(2, 0) = 1, contribution nums[0]^0 = 1.
  • nt = t + st = 0. Bit set? nt & 1 = 0. New k = 1 - 0 = 1. New carry nt >> 1 = 0.
  • Recurse dfs(1, 2, 1, 0).

Branch t = 1 (index 0 chosen once):

  • C(2, 1) = 2, contribution nums[0]^1 = 2.
  • nt = 1. Bit set? nt & 1 = 1. New k = 1 - 1 = 0. New carry 0.
  • Recurse dfs(1, 1, 0, 0), multiplied by 2 * 2 = 4.

Branch t = 2 (index 0 chosen twice):

  • C(2, 2) = 1, contribution nums[0]^2 = 4.
  • nt = 2. Bit set? nt & 1 = 0. New k = 1 - 0 = 1. New carry nt >> 1 = 1 (the carry from 1+1).
  • Recurse dfs(1, 0, 1, 1), multiplied by 1 * 4 = 4.

Step 3: Resolve the sub-calls at index 1 (bit position 1)

dfs(1, 2, 1, 0) (from t=0): still need to place all 2 positions into index 1.

  • t = 2: C(2,2)=1, nums[1]^2 = 9, nt = 2 + 0 = 2, bit = 0, k = 1, carry = 1. Recurse dfs(2, 0, 1, 1).
    • At i == n = 2, j = 0 ✅. Process carry st = 1: lowest bit is 1, so k = 1 - 1 = 0. Now k == 0 ✅ → returns 1. Weighted: 9 * 1 = 9.
  • The other values, t = 0 and t = 1, leave 2 or 1 positions unassigned when i reaches n, so they return 0.
  • So dfs(1, 2, 1, 0) contributes 9.

dfs(1, 1, 0, 0) (from t=1): one position left, budget k = 0.

  • Any t = 1 here sets bit 1, dropping k below 0 → infeasible. Only t = 0 keeps k = 0, but then j = 1 != 0 at the end → returns 0.
  • So this branch contributes 0. (This correctly kills sequences like [0, 1] / [1, 0].)

dfs(1, 0, 1, 1) (from t=2): no positions left (j = 0), carry st = 1 coming in.

  • Only t = 0 is possible. nt = 0 + 1 = 1, bit = 1, k = 1 - 1 = 0, carry = 0. Recurse dfs(2, 0, 0, 0).
    • At i == n, j = 0 ✅, no carry, k == 0 ✅ → returns 1.
  • So dfs(1, 0, 1, 1) returns 1.

Step 4: Combine everything

dfs(0, 2, 1, 0)
  = (t=0)  1 * 1 * dfs(1, 2, 1, 0) = 1 * 1 * 9 = 9
  + (t=1)  2 * 2 * dfs(1, 1, 0, 0) = 4 * 0     = 0
  + (t=2)  1 * 4 * dfs(1, 0, 1, 1) = 4 * 1     = 4
  = 9 + 0 + 4
  = 13

The result is 13, matching the hand-enumerated answer 4 + 9 = 13. ✅
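As a second cross-check on the trace, the hypothetical helper below groups sequences by count vector (t_0, t_1, ...) instead of enumerating ordered sequences. Each surviving vector is weighted by the multinomial coefficient times the product of nums[i]^t_i, using exact integers without the modulus. For m = 2, k = 1, nums = [2, 3], only (2, 0) and (0, 2) survive, with weights 4 and 9. These mirror the two t = 2 branches traced above.

```python
from math import factorial


def weighted_sum_by_counts(m: int, k: int, nums: list[int]) -> int:
    """Sum over count vectors (t_0, ..., t_{n-1}) with sum m.

    Each vector contributes multinomial(m; t) * prod(nums[i] ** t_i) when
    sum(t_i * 2**i) has exactly k set bits. Exact integers, no modulus.
    """
    n = len(nums)
    total = 0

    def rec(i: int, left: int, counts: list[int]) -> None:
        nonlocal total
        if i == n:
            if left:                                   # positions left unassigned
                return
            value = sum(t << idx for idx, t in enumerate(counts))
            if bin(value).count("1") != k:
                return
            ways = factorial(m)                        # number of orderings
            prod_val = 1
            for idx, t in enumerate(counts):
                ways //= factorial(t)
                prod_val *= nums[idx] ** t
            total += ways * prod_val
            return
        for t in range(left + 1):                      # times index i is used
            rec(i + 1, left - t, counts + [t])

    rec(0, m, [])
    return total


print(weighted_sum_by_counts(2, 1, [2, 3]))  # 13
```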

Key takeaways from the trace

  • The t = 2 branch at index 0 corresponds to sequence [0, 0]: two identical low powers carry up into bit 1, producing exactly one set bit, with product nums[0]^2 = 4.
  • The t = 2 branch at index 1 corresponds to sequence [1, 1]: product nums[1]^2 = 9.
  • The binomial coefficient C(j, t) correctly accounts for distinct orderings — though here the magical solutions happen to use a single repeated index (C(2,2)=1), the C(2,1)=2 factor in the dead branch shows how [0,1] and [1,0] would both be counted if they were valid.
  • The carry (st) is what links bit positions together and is the mechanism that turns "two copies of 2^i" into "one copy of 2^(i+1)", which is exactly why repeated indices change the set-bit count.

Solution Implementation

from functools import cache
from typing import List

mx = 30
MOD = 10**9 + 7

# Precompute factorials and their modular inverses
fact = [1] + [0] * mx          # fact[i] = i! % MOD
inv_fact = [1] + [0] * mx      # inv_fact[i] = (i!)^{-1} % MOD

for i in range(1, mx + 1):
    fact[i] = fact[i - 1] * i % MOD
    inv_fact[i] = pow(fact[i], MOD - 2, MOD)  # Fermat's little theorem


def comb(n: int, r: int) -> int:
    """Binomial coefficient C(n, r) modulo MOD."""
    return fact[n] * inv_fact[r] % MOD * inv_fact[n - r] % MOD


class Solution:
    def magicalSum(self, m: int, k: int, nums: List[int]) -> int:
        n = len(nums)

        @cache
        def dfs(idx: int, remaining: int, bits_left: int, carry: int) -> int:
            """
            Sum of array products over all valid completions of this state.

            idx        : current index into nums being processed
            remaining  : number of sequence positions still to assign among nums[idx:]
            bits_left  : set bits still required in the final sum
            carry      : carry propagated into bit position idx from lower bits
            """
            # Pruning: budget exhausted, or all elements used but positions remain
            if bits_left < 0 or (idx == n and remaining > 0):
                return 0

            # Base case: all elements processed, flush the remaining carry bits
            if idx == n:
                while carry:
                    bits_left -= carry & 1
                    carry >>= 1
                return int(bits_left == 0)

            total = 0
            # Assign `cnt` of the remaining positions to the current element nums[idx]
            for cnt in range(remaining + 1):
                combined = cnt + carry                  # add current count to carry
                power = pow(nums[idx], cnt, MOD)         # nums[idx]^cnt
                next_bits = bits_left - (combined & 1)   # bit idx is set iff combined is odd

                # C(remaining, cnt) chooses which positions take index idx;
                # recurse with the higher part of combined as the new carry
                total += comb(remaining, cnt) * power % MOD * \
                    dfs(idx + 1, remaining - cnt, next_bits, combined >> 1)
                total %= MOD

            return total

        ans = dfs(0, m, k, 0)
        dfs.cache_clear()  # release the memo table (dfs is rebuilt on each call)
        return ans

class Solution {
    // Maximum size for factorial precomputation and the carry dimension.
    // N = 31 covers indices 0..30 (m can be up to 30).
    static final int N = 31;

    // Modulus required by the problem; reducing after every multiply keeps values within long.
    static final long MOD = 1_000_000_007L;

    // factorial[i] = i! % MOD
    private static final long[] factorial = new long[N];

    // inverseFactorial[i] = modular multiplicative inverse of i! under MOD
    private static final long[] inverseFactorial = new long[N];

    // 4-dimensional memoization cache:
    // memo[index][remaining][targetSetBits][carry]
    private Long[][][][] memo;

    // Static initializer: precompute factorials and their modular inverses.
    static {
        factorial[0] = 1;
        inverseFactorial[0] = 1;
        for (int i = 1; i < N; ++i) {
            factorial[i] = factorial[i - 1] * i % MOD;
            // Fermat's little theorem: inverse of x is x^(MOD-2) under prime MOD.
            inverseFactorial[i] = qpow(factorial[i], MOD - 2);
        }
    }

    /**
     * Fast exponentiation: computes (base^exponent) % MOD.
     *
     * @param base     the base value
     * @param exponent the exponent value
     * @return base raised to exponent, taken modulo MOD
     */
    public static long qpow(long base, long exponent) {
        long result = 1;
        while (exponent != 0) {
            // If the lowest bit is set, multiply current base into the result.
            if ((exponent & 1) == 1) {
                result = result * base % MOD;
            }
            // Square the base and shift to the next bit.
            base = base * base % MOD;
            exponent >>= 1;
        }
        return result;
    }

    /**
     * Computes the binomial coefficient C(total, choose) modulo MOD,
     * using precomputed factorials and inverse factorials.
     *
     * @param total  the total number of items (n in "n choose r")
     * @param choose the number of items chosen (r in "n choose r")
     * @return C(total, choose) % MOD
     */
    public static long comb(int total, int choose) {
        return factorial[total]
                * inverseFactorial[choose] % MOD
                * inverseFactorial[total - choose] % MOD;
    }

    /**
     * Sums the array products of all magical sequences: sequences of m indices
     * into nums whose sum of powers of two has exactly k set bits.
     *
     * @param m    the length of the sequence (number of positions to assign)
     * @param k    the required number of set bits in the sum of powers of two
     * @param nums the array of base values
     * @return the sum of array products over all magical sequences, modulo MOD
     */
    public int magicalSum(int m, int k, int[] nums) {
        int n = nums.length;
        // Dimensions: index (0..n), remaining (0..m), targetBits (0..k), carry (0..N-1).
        memo = new Long[n + 1][m + 1][k + 1][N];
        long ans = dfs(0, m, k, 0, nums);
        return (int) ans;
    }

    /**
     * Recursive DFS with memoization.
     *
     * @param index            current index into nums being processed
     * @param remaining        number of positions still to assign
     * @param targetBits       set bits still required in the final sum
     * @param carry            carry value propagated from lower bit positions
     * @param nums             the array of base values
     * @return the sum of array products over all valid completions of this state, modulo MOD
     */
    private long dfs(int index, int remaining, int targetBits, int carry, int[] nums) {
        // Prune: negative bit budget, or run out of array but still owe positions.
        if (targetBits < 0 || (index == nums.length && remaining > 0)) {
            return 0;
        }

        // Base case: all elements processed.
        if (index == nums.length) {
            // The leftover carry's set bits must exactly consume the remaining budget.
            while (carry > 0) {
                targetBits -= (carry & 1);
                carry >>= 1;
            }
            return targetBits == 0 ? 1 : 0;
        }

        // Return cached result if this state was already computed.
        if (memo[index][remaining][targetBits][carry] != null) {
            return memo[index][remaining][targetBits][carry];
        }

        long result = 0;
        // Try assigning t positions to the current element (0..remaining).
        for (int t = 0; t <= remaining; t++) {
            // Combine current selections with incoming carry.
            int newCarry = t + carry;
            // Consume one bit of budget if the current bit position is set.
            int newTargetBits = targetBits - (newCarry & 1);
            // Contribution factor: nums[index]^t.
            long power = qpow(nums[index], t);
            // Multiply by the number of ways to choose which t of the remaining positions are used,
            // then recurse with the carry shifted to the next bit position.
            long contribution = comb(remaining, t)
                    * power % MOD
                    * dfs(index + 1, remaining - t, newTargetBits, newCarry >> 1, nums) % MOD;
            result = (result + contribution) % MOD;
        }

        // Cache and return.
        return memo[index][remaining][targetBits][carry] = result;
    }
}

#include <vector>
using namespace std;

class Solution {
public:
    // Public entry point. Method name kept as required.
    int magicalSum(int m, int k, vector<int>& nums) {
        this->nums = nums;
        int n = static_cast<int>(nums.size());

        // Precompute factorials and their modular inverses once.
        ensureCombInit();

        // dp[i][j][bitsLeft][state]:
        //   i        -> index into nums currently being processed
        //   j        -> remaining count of positions still to assign
        //   bitsLeft -> remaining number of set bits (popcount budget) we must place
        //   state    -> accumulated carry coming from lower bit positions
        // Initialized to -1 to mark "not computed yet".
        dp.assign(
            n + 1,
            vector<vector<vector<long long>>>(
                m + 1,
                vector<vector<long long>>(
                    k + 1,
                    vector<long long>(kMaxN, -1))));

        return static_cast<int>(dfs(0, m, k, 0));
    }

private:
    // Maximum dimension for factorial tables and state size.
    static const int kMaxN = 31;
    static const long long kMod = 1'000'000'007;

    // factorial[i] = i! mod kMod
    // invFactorial[i] = (i!)^{-1} mod kMod
    long long factorial[kMaxN];
    long long invFactorial[kMaxN];
    bool combReady = false;

    vector<int> nums;

    // Memoization table (see dimensions described in magicalSum).
    vector<vector<vector<vector<long long>>>> dp;

    // Fast modular exponentiation: returns (base^exp) mod kMod.
    long long qpow(long long base, long long exp) {
        long long result = 1;
        base %= kMod;
        while (exp) {
            if (exp & 1) {
                result = result * base % kMod;
            }
            base = base * base % kMod;
            exp >>= 1;
        }
        return result;
    }

    // Lazily precompute factorials and modular inverse factorials.
    void ensureCombInit() {
        if (combReady) {
            return;
        }
        factorial[0] = invFactorial[0] = 1;
        for (int i = 1; i < kMaxN; ++i) {
            factorial[i] = factorial[i - 1] * i % kMod;
            // Fermat's little theorem: inverse = a^(MOD-2) mod MOD.
            invFactorial[i] = qpow(factorial[i], kMod - 2);
        }
        combReady = true;
    }

    // Binomial coefficient C(total, choose) mod kMod.
    long long comb(int total, int choose) {
        return factorial[total] * invFactorial[choose] % kMod
               * invFactorial[total - choose] % kMod;
    }

    // Recursive DP with memoization.
    //   index    -> current position in nums
    //   remain   -> remaining positions to assign among nums[index..]
    //   bitsLeft -> remaining popcount budget to satisfy
    //   state    -> carry accumulated from lower bit positions
    long long dfs(int index, int remain, int bitsLeft, int state) {
        // Prune: overspent the popcount budget, or finished nums but still
        // have leftover positions to assign.
        if (bitsLeft < 0 || (index == static_cast<int>(nums.size()) && remain > 0)) {
            return 0;
        }

        // Reached the end with all positions assigned: flush remaining carry bits
        // into the popcount budget and check it lands exactly on zero.
        if (index == static_cast<int>(nums.size())) {
            while (state > 0) {
                bitsLeft -= (state & 1);
                state >>= 1;
            }
            return bitsLeft == 0 ? 1 : 0;
        }

        long long& result = dp[index][remain][bitsLeft][state];
        if (result != -1) {
            return result;
        }

        result = 0;
        // Try assigning t copies to the current nums[index].
        for (int t = 0; t <= remain; ++t) {
            // Add the incoming carry to the count chosen at this position.
            int combined = t + state;
            // This bit contributes (combined & 1) to the popcount budget.
            int nextBits = bitsLeft - (combined & 1);
            // Contribution of choosing nums[index] exactly t times.
            long long powVal = qpow(nums[index], t);
            long long tmp = comb(remain, t) * powVal % kMod
                            * dfs(index + 1, remain - t, nextBits, combined >> 1) % kMod;
            result = (result + tmp) % kMod;
        }
        return result;
    }
};

// Modulus used throughout the computation.
const kMod = 1_000_000_007n;

// Maximum dimension for factorial tables and the state (carry) size.
const kMaxN = 31;

// factorial[i]    = i! mod kMod
// invFactorial[i] = (i!)^{-1} mod kMod
const factorial: bigint[] = new Array(kMaxN).fill(0n);
const invFactorial: bigint[] = new Array(kMaxN).fill(0n);

// Flag to ensure factorial tables are computed only once.
let combReady = false;

// Shared input array (assigned at the start of magicalSum).
let nums: number[] = [];

// Memoization table:
//   dp[index][remain][bitsLeft][state]
// Initialized lazily inside magicalSum. A value of -1n means "not computed yet".
let dp: bigint[][][][] = [];

/**
 * Fast modular exponentiation: returns (base^exp) mod kMod.
 */
function qpow(base: bigint, exp: bigint): bigint {
  let result = 1n;
  base %= kMod;
  while (exp > 0n) {
    if (exp & 1n) {
      result = (result * base) % kMod;
    }
    base = (base * base) % kMod;
    exp >>= 1n;
  }
  return result;
}

/**
 * Lazily precompute factorials and modular inverse factorials.
 */
function ensureCombInit(): void {
  if (combReady) {
    return;
  }
  factorial[0] = 1n;
  invFactorial[0] = 1n;
  for (let i = 1; i < kMaxN; ++i) {
    factorial[i] = (factorial[i - 1] * BigInt(i)) % kMod;
    // Fermat's little theorem: inverse = a^(MOD-2) mod MOD.
    invFactorial[i] = qpow(factorial[i], kMod - 2n);
  }
  combReady = true;
}

/**
 * Binomial coefficient C(total, choose) mod kMod.
 */
function comb(total: number, choose: number): bigint {
  return (
    ((factorial[total] * invFactorial[choose]) % kMod) *
    invFactorial[total - choose]
  ) % kMod;
}

/**
 * Recursive DP with memoization.
 *   index    -> current position in nums
 *   remain   -> remaining elements to distribute among nums[index..]
 *   bitsLeft -> remaining popcount budget to satisfy
 *   state    -> carry accumulated from lower bit positions
 */
function dfs(index: number, remain: number, bitsLeft: number, state: number): bigint {
  // Prune: overspent the popcount budget, or finished nums but still
  // have leftover elements to place.
  if (bitsLeft < 0 || (index === nums.length && remain > 0)) {
    return 0n;
  }

  // Reached the end with all elements placed: flush remaining carry bits
  // into the popcount budget and check it lands exactly on zero.
  if (index === nums.length) {
    while (state > 0) {
      bitsLeft -= state & 1;
      state >>= 1;
    }
    return bitsLeft === 0 ? 1n : 0n;
  }

  // Return cached value if present.
  const cached = dp[index][remain][bitsLeft][state];
  if (cached !== -1n) {
    return cached;
  }

  let result = 0n;
  // Try assigning t copies to the current nums[index].
  for (let t = 0; t <= remain; ++t) {
    // Add the incoming carry to the count chosen at this position.
    const combined = t + state;
    // This bit contributes (combined & 1) to the popcount budget.
    const nextBits = bitsLeft - (combined & 1);
    // Contribution of choosing nums[index] exactly t times.
    const powVal = qpow(BigInt(nums[index]), BigInt(t));
    const tmp =
      (((comb(remain, t) * powVal) % kMod) *
        dfs(index + 1, remain - t, nextBits, combined >> 1)) %
      kMod;
    result = (result + tmp) % kMod;
  }

  dp[index][remain][bitsLeft][state] = result;
  return result;
}

/**
 * Public entry point. Method name kept as required.
 */
function magicalSum(m: number, k: number, nums_: number[]): number {
  nums = nums_;
  const n = nums.length;

  // Precompute factorials and their modular inverses once.
  ensureCombInit();

  // Initialize dp[i][j][bitsLeft][state] to -1n ("not computed yet"):
  //   i        -> index into nums currently being processed
  //   j        -> remaining count of elements still to distribute
  //   bitsLeft -> remaining number of set bits (popcount budget) we must place
  //   state    -> accumulated carry coming from lower bit positions
  dp = Array.from({ length: n + 1 }, () =>
    Array.from({ length: m + 1 }, () =>
      Array.from({ length: k + 1 }, () =>
        new Array<bigint>(kMaxN).fill(-1n),
      ),
    ),
  );

  return Number(dfs(0, m, k, 0));
}


Time and Space Complexity

Time Complexity: O(n × m³ × k)

The core of the algorithm is the recursive function dfs(i, j, k, st), whose complexity is determined by the number of distinct states multiplied by the work done per state.

  • State analysis:

    • i ranges over the array indices, giving O(n) values, where n is the length of nums.
    • j represents the remaining count to distribute, ranging from 0 to m, giving O(m) values.
    • k represents the remaining required popcount, giving O(k) values.
    • st is the carry accumulated from lower bits. Since at each step at most m items contribute and carries are halved (nt >> 1), the value of st is bounded by O(m), giving O(m) values.

    Thus the total number of states is O(n × m × k × m) = O(n × m² × k).

  • Work per state:

    • Inside each call, the loop for t in range(j + 1) iterates up to m + 1 times, contributing a factor of O(m).
    • comb(j, t) is O(1) with the precomputed factorials, and pow(nums[i], t, mod) takes O(log t) multiplications with t ≤ m ≤ 30, so both are effectively constant.

    So each state does O(m) work.

  • Total: combining states and per-state work yields O(n × m² × k) × O(m) = O(n × m³ × k).

Space Complexity: O(n × m² × k)

The space is dominated by the memoization cache (@cache), which stores one entry for each distinct state (i, j, k, st). As analyzed above, the number of distinct states is O(n × m² × k). The recursion depth is only O(n), and the precomputed factorial arrays f and g use O(mx) space, both of which are dominated by the cache. Hence the overall space complexity is O(n × m² × k).


Common Pitfalls

Pitfall 1: Forgetting to Flush the Final Carry's Set Bits

The most frequent and damaging mistake is treating the popcount condition as if it only concerns the bits produced while iterating over nums[0..n-1]. When you finish processing the last index (idx == n), there can still be a leftover carry representing bits at positions >= n that have not yet been counted toward bits_left.

Buggy version:

if idx == n:
    # WRONG: ignores any remaining carry bits above position n
    return int(bits_left == 0)

Consider m = 2, k = 1, nums.length = 1. Choosing index 0 twice gives 2^0 + 2^0 = 2 = (10)_2, which has one set bit, at position 1. While processing idx = 0, the parity combined & 1 = 0 consumes no bit, and a carry of 1 flows out. If you stop at idx == n without flushing, that genuine set bit is never counted, so the buggy version rejects this magical sequence. It can also accept sequences whose leftover carry holds extra set bits.

Fix: Drain every remaining bit of carry before deciding feasibility.

if idx == n:
    while carry:
        bits_left -= carry & 1
        carry >>= 1
    return int(bits_left == 0)

Note that the feasibility check after the flush is bits_left == 0, not bits_left <= 0. The set bits in the carry can push the budget negative, and those states must be rejected too.
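The difference is easy to observe. The sketch below is a compact re-implementation of the memoized dfs with a hypothetical flush_carry switch. For brevity it uses math.comb in place of the factorial tables. On m = 2, k = 1, nums = [5], the only magical sequence is [0, 0] with product 25, and the buggy variant misses it.

```python
from functools import cache
from math import comb

MOD = 10**9 + 7


def magical_sum(m: int, k: int, nums: list[int], flush_carry: bool = True) -> int:
    """Memoized dfs; flush_carry=False reproduces Pitfall 1 (hypothetical switch)."""
    n = len(nums)

    @cache
    def dfs(idx: int, remaining: int, bits_left: int, carry: int) -> int:
        if bits_left < 0 or (idx == n and remaining > 0):
            return 0
        if idx == n:
            if flush_carry:                 # count set bits above the last index
                while carry:
                    bits_left -= carry & 1
                    carry >>= 1
            return int(bits_left == 0)
        total = 0
        for cnt in range(remaining + 1):
            combined = cnt + carry
            weight = comb(remaining, cnt) * pow(nums[idx], cnt, MOD) % MOD
            sub = dfs(idx + 1, remaining - cnt, bits_left - (combined & 1), combined >> 1)
            total = (total + weight * sub) % MOD
        return total

    return dfs(0, m, k, 0)


print(magical_sum(2, 1, [5]))                     # 25
print(magical_sum(2, 1, [5], flush_carry=False))  # 0
```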


Pitfall 2: Confusing "Set Bits" with "Total Count of Selections"

A natural but incorrect intuition is to think the answer relates to selecting exactly k distinct indices, ignoring that carries merge powers of two. Because 2^i + 2^i = 2^{i+1}, picking the same index multiple times does not add multiple set bits — it propagates a carry.

# WRONG mental model: "k distinct indices chosen"
next_bits = bits_left - (1 if cnt > 0 else 0)

The correct rule consumes a bit based on the parity of cnt + carry, not on whether the index was used at all:

combined = cnt + carry
next_bits = bits_left - (combined & 1)   # parity determines the resulting bit
carry_out = combined >> 1                # everything above parity carries forward

Missing this leads to silently wrong counts on any input where some index is reused (which is allowed and very common).
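A small counterexample makes the gap concrete. With every nums value equal to 1, the sum of array products is simply the number of magical sequences. The hypothetical helpers below count those sequences by brute force under the correct popcount rule and under the wrong "k distinct indices" model. For m = 3, k = 1 and two indices, the correct answer is 3: the orderings of [0, 0, 1], whose sum 1 + 1 + 2 = 4 has a single set bit. The distinct-index model instead counts [0, 0, 0] and [1, 1, 1], giving 2.

```python
from itertools import product


def count_magical(m: int, k: int, n: int) -> int:
    """Number of magical sequences over n indices (all nums equal to 1)."""
    return sum(bin(sum(1 << i for i in seq)).count("1") == k
               for seq in product(range(n), repeat=m))


def count_distinct_model(m: int, k: int, n: int) -> int:
    """The WRONG model: sequences that use exactly k distinct indices."""
    return sum(len(set(seq)) == k for seq in product(range(n), repeat=m))


print(count_magical(3, 1, 2), count_distinct_model(3, 1, 2))  # 3 2
```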


Pitfall 3: Forgetting the Multinomial (Ordering) Factor

Since two sequences with the same multiset of indices in different orders are distinct, you must multiply by C(remaining, cnt) — the number of ways to choose which of the remaining positions take value idx.

# WRONG: counts each multiset once, undercounting ordered sequences
total += power * dfs(idx + 1, remaining - cnt, next_bits, combined >> 1)

# CORRECT: account for which positions receive index idx
total += comb(remaining, cnt) * power % MOD * \
         dfs(idx + 1, remaining - cnt, next_bits, combined >> 1)

The product C(m, t0) * C(m - t0, t1) * ... telescopes to the multinomial coefficient m! / (t0! t1! ...), which is the number of distinct orderings of the multiset.
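The telescoping identity can be confirmed numerically. The two hypothetical helpers below compute the same quantity in two ways. telescoped multiplies binomials as the dfs does, one index at a time. multinomial uses m! divided by the factorial of each count. For counts [2, 1, 3], both give 6! / (2! 1! 3!) = 60.

```python
from math import comb, factorial


def telescoped(counts: list[int]) -> int:
    """Product C(m, t0) * C(m - t0, t1) * ..., as accumulated by the dfs."""
    remaining = sum(counts)
    ways = 1
    for t in counts:
        ways *= comb(remaining, t)   # choose positions for this index
        remaining -= t
    return ways


def multinomial(counts: list[int]) -> int:
    """m! / (t0! * t1! * ...): the number of orderings of the multiset."""
    result = factorial(sum(counts))
    for t in counts:
        result //= factorial(t)
    return result


print(telescoped([2, 1, 3]), multinomial([2, 1, 3]))  # 60 60
```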


Pitfall 4: Carry Range and Cache Explosion

The carry can grow up to roughly m, because at most m powers of two land on one position before shifting. Sizing the carry dimension too tightly (for example, assuming it never exceeds 1) therefore causes out-of-bounds memo accesses in the Java and C++ versions. A separate slip concerns the memo table itself:

ans = dfs(0, m, k, 0)
dfs.cache_clear()   # release the memo table once the answer is known
return ans

In the reference solution, dfs is defined inside magicalSum. Every call therefore builds a fresh closure with its own empty cache, and cache_clear() mainly frees memory. Suppose dfs were instead defined once at module or class level and read a nums that changes between test cases. Stale entries from an earlier input would then be returned for the new one, and clearing the cache (or rebuilding dfs per call) would be essential.
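A deliberately minimal sketch shows how a shared cache goes stale. It uses a hypothetical module-level NUMS list and a hypothetical cached function weight that reads NUMS at call time. Because the cache key is only the argument i, rebinding NUMS for a "new test case" does not invalidate the stored result. Only cache_clear() does.

```python
from functools import cache

NUMS = [2, 3]   # hypothetical module-level input that a shared function reads


@cache
def weight(i: int) -> int:
    """Cached at module level, so results persist across 'test cases'."""
    return NUMS[i] * NUMS[i]


first = weight(0)          # computed from NUMS = [2, 3]
NUMS = [5, 7]              # next test case rebinds the input
stale = weight(0)          # served from the cache: still the old value
weight.cache_clear()
fresh = weight(0)          # recomputed from NUMS = [5, 7]
print(first, stale, fresh)  # 4 4 25
```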


Pitfall 5: Precomputed Table Too Small

The factorial tables are sized by mx = 30, and binomial coefficients C(remaining, cnt) need indices up to m. The constraints cap m at 30, but if the code is reused with m > mx, comb indexes past the end of the tables. That raises an IndexError in Python and an ArrayIndexOutOfBoundsException in Java, and it is undefined behavior in C++.

Fix: Size the tables to at least m + 1 (or max(m, k) + 1):

mx = max(31, m + 1)   # ensure fact/inv_fact cover all needed indices
fact = [1] * mx
inv_fact = [1] * mx
for i in range(1, mx):
    fact[i] = fact[i - 1] * i % MOD
inv_fact[mx - 1] = pow(fact[mx - 1], MOD - 2, MOD)
for i in range(mx - 1, 0, -1):
    inv_fact[i - 1] = inv_fact[i] * i % MOD   # O(n) inverse, avoids per-element pow

The backward recurrence for inverse factorials also removes the per-element pow call, turning the precompute from O(mx log MOD) into O(mx).
