Count of Smaller Numbers after Self | Number of Swaps to Sort | Algorithm Swap

You are given an integer array nums and have to return a new array counts, where counts[i] is the number of elements to the right of nums[i] that are smaller than nums[i].

Examples:

Example 1:

Input: [5,2,6,1]

Output: [2,1,1,0]

Explanation:

For 5, there are 2 smaller numbers after it (2 and 1).

For 2, there is 1 smaller number after it (1).

For 6, there is also 1 smaller number after it (1).

For 1, there are no smaller numbers after it.

Hence, we have [2, 1, 1, 0].

Number of swaps to sort

Another version of the question is:

If we sort the nums array by repeatedly finding the smallest pair i, j where i < j and nums[i] > nums[j] and swapping those two elements, how many swaps are needed?

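To answer that question, we just sum the numbers in the output array above: 2 + 1 + 1 + 0 = 4 swaps. The sum is the total number of inverted pairs (i < j with nums[i] > nums[j]), because each such pair is counted exactly once, at its left element.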
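To make the swap procedure concrete, here is a small brute-force simulation. It assumes that "smallest pair" means the lexicographically smallest (i, j), and it assumes distinct values, as in the example. With repeated values the two numbers can differ: for [2, 2, 1] the procedure stops after 1 swap, while the counts sum to 2. The function name `simulate_swaps` is hypothetical and only for illustration.

```python
from typing import List

def simulate_swaps(nums: List[int]) -> int:
    """Repeatedly swap the lexicographically smallest inverted pair (i, j)
    until no inversion is left, and return how many swaps were made.
    Assumes distinct values (see the note above)."""
    arr = list(nums)  # work on a copy so the input is not modified
    swaps = 0
    while True:
        # first (i, j) in lexicographic order with i < j and arr[i] > arr[j]
        pair = next(
            ((i, j) for i in range(len(arr)) for j in range(i + 1, len(arr)) if arr[i] > arr[j]),
            None,
        )
        if pair is None:
            return swaps
        i, j = pair
        arr[i], arr[j] = arr[j], arr[i]
        swaps += 1

print(simulate_swaps([5, 2, 6, 1]))  # 4, matching 2 + 1 + 1 + 0
```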


Explanation

Intuition

The brute force solution is easy and intuitive: we go through the list, and for each element we scan the elements after it and count how many are smaller. This takes O(N^2) time, and we can do better.
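A minimal sketch of the brute force described above is shown here. It is useful as a reference for checking faster versions. The name `count_smaller_brute_force` is hypothetical.

```python
from typing import List

def count_smaller_brute_force(nums: List[int]) -> List[int]:
    """O(N^2): for each position, scan everything to its right."""
    counts = []
    for i, x in enumerate(nums):
        # count elements after index i that are strictly smaller than nums[i]
        counts.append(sum(1 for y in nums[i + 1:] if y < x))
    return counts

print(count_smaller_brute_force([5, 2, 6, 1]))  # [2, 1, 1, 0]
```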

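Observe that to reduce the complexity, we need to determine the smaller counts of many numbers in one go, rather than comparing one pair at a time. Sorted order is what makes this possible. Once a group of numbers is sorted, a single position in it tells us how many of them are smaller than a given value.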

But sorting destroys the original order of the array. What can we do about that?

Recall from the introduction to divide and conquer that the common approach is to divide the given data into two components, assume each component is already solved, and then merge the results.

What if we divide the numbers into two components by index and then sort each one separately?

Since we divided the original array by index, sorting the two components does not change one fact. Every element in the left component still had a smaller original index than every element in the right component.

We can utilize this fact when we combine the two arrays together.
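A small sketch shows that this fact holds. For each element of the left half, it counts the smaller elements in the sorted right half using binary search (`bisect.bisect_left`). This demonstrates the invariant only. The merge-based solution in this article does the same counting with a linear two-pointer merge instead of binary search. The helper name `cross_counts` is hypothetical.

```python
from bisect import bisect_left
from typing import List

def cross_counts(nums: List[int]) -> List[int]:
    """For each element of the left half, count the smaller elements in the
    right half. Sorting the right half does not change these counts, because
    every right-half element still comes after every left-half element in the
    original array."""
    mid = len(nums) // 2
    left, right = nums[:mid], sorted(nums[mid:])
    # bisect_left gives the number of elements in the sorted right half that are < x
    return [bisect_left(right, x) for x in left]

print(cross_counts([5, 2, 6, 1]))  # [1, 1]
```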

Thus, to solve this problem, we first split the data into two components, the left and the right. We then assume that both sub-problems are already solved. That is, for each number we already know how many smaller numbers follow it within its own component. What remains is to find, for each number in the left component, how many elements of the right component are smaller than it.

We get this during the merge step. We merge the two sorted components with two pointers, l and r, always taking the smaller front element and taking the left one on ties. When we take left[l], the r elements already taken from the right component are exactly the right-component elements smaller than left[l], so we add r to its count. Adding this to the count from its own component gives the full answer for that element.

Thus, we have successfully solved the problem.
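To trace this merge step by hand, the sketch below merges the two sorted halves of [5, 2, 6, 1] as (original index, value) pairs. It starts from the counts the sub-problems would already have produced and adds r for each left element it emits. The function name `merge_and_count` is hypothetical. Its comparison logic mirrors the `merge` function in the implementation below.

```python
from typing import List, Tuple

Pair = Tuple[int, int]  # (original index, value)

def merge_and_count(left: List[Pair], right: List[Pair], counts: List[int]) -> List[Pair]:
    """Merge two halves sorted by value. Whenever a left element is emitted,
    the r right elements already emitted are exactly the right-half values
    smaller than it, so r is added to its count."""
    result, l, r = [], 0, 0
    while l < len(left) or r < len(right):
        # take from the left on ties, so equal values are not counted as smaller
        if r >= len(right) or (l < len(left) and left[l][1] <= right[r][1]):
            result.append(left[l])
            counts[left[l][0]] += r
            l += 1
        else:
            result.append(right[r])
            r += 1
    return result

# Halves of [5, 2, 6, 1], each already sorted by value. The sub-problems have
# contributed counts [1, 0, 1, 0]: 5 > 2 on the left, 6 > 1 on the right.
counts = [1, 0, 1, 0]
merged = merge_and_count([(1, 2), (0, 5)], [(3, 1), (2, 6)], counts)
print(counts)  # [2, 1, 1, 0]
```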

So, what is the run time of our improved solution? Each recursion splits the problem into two halves, solves each half recursively, and then spends O(N) time on the merge. Thus we have

$$T(N) = 2T(N/2) + O(N)$$

This recurrence will yield a total run time of O(N log N).
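Unrolling the recurrence shows where the N log N comes from. This sketch assumes that N is a power of two and that the O(N) merge costs at most cN for some constant c:

$$
\begin{aligned}
T(N) &= 2T(N/2) + cN \\
     &= 4T(N/4) + 2cN \\
     &= 2^k\,T(N/2^k) + k\,cN \\
     &= N\,T(1) + cN\log_2 N \qquad (k = \log_2 N) \\
     &= O(N \log N).
\end{aligned}
$$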

Implementation

```python
from typing import List

def count_smaller(nums: List[int]) -> List[int]:
    # smaller_arr[i] accumulates how many smaller elements lie to the right of nums[i]
    smaller_arr = [0] * len(nums)

    def merge_sort(nums):
        # nums is a list of (original index, value) pairs, sorted here by value
        if len(nums) <= 1:
            return nums
        mid = len(nums) // 2
        left = merge_sort(nums[:mid])
        right = merge_sort(nums[mid:])
        return merge(left, right)

    def merge(left, right):
        result = []
        l, r = 0, 0
        while l < len(left) or r < len(right):
            # take from the left on ties so equal values are not counted as smaller
            if r >= len(right) or (l < len(left) and left[l][1] <= right[r][1]):
                result.append(left[l])
                # the r right-half elements already emitted are all smaller than left[l]
                smaller_arr[left[l][0]] += r
                l += 1
            else:
                result.append(right[r])
                r += 1
        return result

    merge_sort(list(enumerate(nums)))
    return smaller_arr

if __name__ == '__main__':
    nums = [int(x) for x in input().split()]
    res = count_smaller(nums)
    print(' '.join(map(str, res)))
```
```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.stream.Collectors;

class Solution {
    // A value paired with its index in the original array
    public static class Element {
        int val, ind;
        public Element(int val, int ind) {
            this.val = val;
            this.ind = ind;
        }
    }

    public static List<Integer> smallerArr = new ArrayList<Integer>();

    public static List<Element> mergeSort(List<Element> nums) {
        if (nums.size() <= 1) {
            return nums;
        }
        int mid = nums.size() / 2;
        List<Element> splitLeft = new ArrayList<Element>();
        List<Element> splitRight = new ArrayList<Element>();
        for (int i = 0; i < nums.size(); i++) {
            if (i < mid) splitLeft.add(nums.get(i));
            else splitRight.add(nums.get(i));
        }
        List<Element> left = mergeSort(splitLeft);
        List<Element> right = mergeSort(splitRight);
        return merge(left, right);
    }

    public static List<Element> merge(List<Element> left, List<Element> right) {
        List<Element> result = new ArrayList<Element>();
        int l = 0;
        int r = 0;
        while (l < left.size() || r < right.size()) {
            // take from the left on ties so equal values are not counted as smaller
            if (r >= right.size() || (l < left.size() && left.get(l).val <= right.get(r).val)) {
                Element e = left.get(l);
                result.add(e);
                // the r right-half elements already emitted are all smaller than e
                smallerArr.set(e.ind, smallerArr.get(e.ind) + r);
                l += 1;
            } else {
                result.add(right.get(r));
                r += 1;
            }
        }
        return result;
    }

    public static List<Integer> countSmaller(List<Integer> nums) {
        // start from a fresh list so repeated calls do not accumulate old counts
        smallerArr = new ArrayList<Integer>();
        for (int i = 0; i < nums.size(); i++) smallerArr.add(0);
        List<Element> temp = new ArrayList<Element>();
        for (int i = 0; i < nums.size(); i++) temp.add(new Element(nums.get(i), i));
        mergeSort(temp);
        return smallerArr;
    }

    public static List<String> splitWords(String s) {
        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        List<Integer> nums = splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList());
        scanner.close();
        List<Integer> res = countSmaller(nums);
        System.out.println(res.stream().map(String::valueOf).collect(Collectors.joining(" ")));
    }
}
```
```javascript
function countSmaller(nums) {
    // smallerArr[i] accumulates how many smaller elements lie to the right of nums[i]
    const smallerArr = Array(nums.length).fill(0);

    // Sorts [index, value] pairs by value
    function mergeSort(nums) {
        if (nums.length <= 1) return nums;
        const mid = Math.floor(nums.length / 2);
        const left = mergeSort(nums.slice(0, mid));
        const right = mergeSort(nums.slice(mid));
        return merge(left, right);
    }

    function merge(left, right) {
        const result = [];
        let l = 0, r = 0;
        while (l < left.length || r < right.length) {
            // take from the left on ties so equal values are not counted as smaller
            if (r >= right.length || (l < left.length && left[l][1] <= right[r][1])) {
                result.push(left[l]);
                // the r right-half elements already emitted are all smaller than left[l]
                smallerArr[left[l][0]] += r;
                l += 1;
            } else {
                result.push(right[r]);
                r += 1;
            }
        }
        return result;
    }

    // pair each value with its original index
    mergeSort(nums.map((e, i) => [i, e]));
    return smallerArr;
}

function splitWords(s) {
    return s == "" ? [] : s.split(' ');
}

function* main() {
    const nums = splitWords(yield).map((v) => parseInt(v, 10));
    const res = countSmaller(nums);
    console.log(res.join(' '));
}

class EOFError extends Error {}
{
    const gen = main();
    const next = (line) => gen.next(line).done && process.exit();
    let buf = '';
    next();
    process.stdin.setEncoding('utf8');
    process.stdin.on('data', (data) => {
        const lines = (buf + data).split('\n');
        buf = lines.pop();
        lines.forEach(next);
    });
    process.stdin.on('end', () => {
        buf && next(buf);
        gen.throw(new EOFError());
    });
}
```
```cpp
#include <algorithm> // copy
#include <cstddef> // size_t
#include <iostream> // cin, cout
#include <iterator> // back_inserter, istream_iterator, ostream_iterator, prev
#include <sstream> // istringstream
#include <string> // getline, string
#include <vector> // vector

// Each element is an {original index, value} pair; both halves are sorted by value.
std::vector<std::vector<int>> merge(const std::vector<std::vector<int>>& left, const std::vector<std::vector<int>>& right, std::vector<int>& counts) {
    std::vector<std::vector<int>> res;
    std::size_t l = 0, r = 0;
    while (l < left.size() || r < right.size()) {
        // take from the left on ties so equal values are not counted as smaller
        if (r >= right.size() || (l < left.size() && left[l][1] <= right[r][1])) {
            res.emplace_back(left[l]);
            // the r right-half elements already emitted are all smaller than left[l]
            counts[left[l][0]] += static_cast<int>(r);
            l++;
        } else {
            res.emplace_back(right[r]);
            r++;
        }
    }
    return res;
}

std::vector<std::vector<int>> merge_sort(const std::vector<std::vector<int>>& nums, std::vector<int>& counts) {
    if (nums.size() <= 1) return nums;
    int mid = static_cast<int>(nums.size() / 2);
    std::vector<std::vector<int>> split_left(nums.begin(), nums.begin() + mid);
    std::vector<std::vector<int>> split_right(nums.begin() + mid, nums.end());
    std::vector<std::vector<int>> left = merge_sort(split_left, counts);
    std::vector<std::vector<int>> right = merge_sort(split_right, counts);
    return merge(left, right, counts);
}

std::vector<int> count_smaller(const std::vector<int>& nums) {
    std::vector<int> counts(nums.size(), 0);
    // pair each value with its original index
    std::vector<std::vector<int>> idx_num_mapping;
    for (int i = 0; i < static_cast<int>(nums.size()); i++) {
        idx_num_mapping.push_back({ i, nums[i] });
    }
    merge_sort(idx_num_mapping, counts);
    return counts;
}

template<typename T>
std::vector<T> get_words() {
    std::string line;
    std::getline(std::cin, line);
    std::istringstream ss{line};
    std::vector<T> v;
    std::copy(std::istream_iterator<T>{ss}, std::istream_iterator<T>{}, std::back_inserter(v));
    return v;
}

template<typename T>
void put_words(const std::vector<T>& v) {
    if (!v.empty()) {
        std::copy(v.begin(), std::prev(v.end()), std::ostream_iterator<T>{std::cout, " "});
        std::cout << v.back();
    }
    std::cout << '\n';
}

int main() {
    std::vector<int> nums = get_words<int>();
    std::vector<int> res = count_smaller(nums);
    put_words(res);
}
```
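To gain some confidence in the merge-based approach, you can compare it against the O(N^2) brute force on many small random inputs. The sketch below does this. The small value range makes duplicates, and therefore tie handling, likely. The helper `count_smaller_brute_force` is hypothetical, and `count_smaller` repeats the logic of the Python implementation above so the snippet is self-contained.

```python
import random
from typing import List

def count_smaller_brute_force(nums: List[int]) -> List[int]:
    # O(N^2) reference: scan everything to the right of each position
    return [sum(1 for y in nums[i + 1:] if y < x) for i, x in enumerate(nums)]

def count_smaller(nums: List[int]) -> List[int]:
    # Same merge-sort approach as the Python implementation above
    counts = [0] * len(nums)

    def merge_sort(pairs):
        if len(pairs) <= 1:
            return pairs
        mid = len(pairs) // 2
        left, right = merge_sort(pairs[:mid]), merge_sort(pairs[mid:])
        result, l, r = [], 0, 0
        while l < len(left) or r < len(right):
            if r >= len(right) or (l < len(left) and left[l][1] <= right[r][1]):
                result.append(left[l])
                counts[left[l][0]] += r
                l += 1
            else:
                result.append(right[r])
                r += 1
        return result

    merge_sort(list(enumerate(nums)))
    return counts

random.seed(0)  # fixed seed so the check is reproducible
for _ in range(500):
    nums = [random.randint(-5, 5) for _ in range(random.randint(0, 12))]
    assert count_smaller(nums) == count_smaller_brute_force(nums), nums
print("all random tests passed")
```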

If the problem asks only for the number of swaps, we don't need the counts array. We can simply keep a single counter and add r to it each time a left element is placed during the merge.

```python
from typing import List

def number_of_swaps_to_sort(nums: List[int]) -> int:
    # total number of inverted pairs, i.e. the sum of the counts array
    count = 0

    def merge_sort(nums):
        # nums is a list of (original index, value) pairs, sorted here by value
        if len(nums) <= 1:
            return nums
        mid = len(nums) // 2
        left = merge_sort(nums[:mid])
        right = merge_sort(nums[mid:])
        return merge(left, right)

    def merge(left, right):
        nonlocal count
        result = []
        l, r = 0, 0
        while l < len(left) or r < len(right):
            # take from the left on ties so equal values are not counted as smaller
            if r >= len(right) or (l < len(left) and left[l][1] <= right[r][1]):
                result.append(left[l])
                # the r right-half elements already emitted each form an inversion with left[l]
                count += r
                l += 1
            else:
                result.append(right[r])
                r += 1
        return result

    merge_sort(list(enumerate(nums)))
    return count

if __name__ == '__main__':
    nums = [int(x) for x in input().split()]
    res = number_of_swaps_to_sort(nums)
    print(res)
```
```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.stream.Collectors;

class Solution {
    // A value paired with its index in the original array
    public static class Number {
        int index;
        int val;
        public Number(int i, int v) {
            index = i;
            val = v;
        }
    }

    // long, because the number of inversions can exceed Integer.MAX_VALUE for large inputs
    private static long count;

    private static List<Number> mergeSort(List<Number> nums) {
        if (nums.size() <= 1) {
            return nums;
        }
        int mid = nums.size() / 2;
        List<Number> left = mergeSort(nums.subList(0, mid));
        List<Number> right = mergeSort(nums.subList(mid, nums.size()));
        return merge(left, right);
    }

    private static List<Number> merge(List<Number> left, List<Number> right) {
        List<Number> result = new ArrayList<>();
        int l = 0;
        int r = 0;
        while (l < left.size() || r < right.size()) {
            // take from the left on ties so equal values are not counted as smaller
            if (r >= right.size() || (l < left.size() && left.get(l).val <= right.get(r).val)) {
                result.add(left.get(l));
                // the r right-half elements already emitted each form an inversion with left.get(l)
                count += r;
                l++;
            } else {
                result.add(right.get(r));
                r++;
            }
        }
        return result;
    }

    public static long numberOfSwapsToSort(List<Integer> nums) {
        count = 0; // reset so repeated calls start from zero
        List<Number> numbers = new ArrayList<>();
        for (int i = 0; i < nums.size(); i++) {
            numbers.add(new Number(i, nums.get(i)));
        }

        mergeSort(numbers);
        return count;
    }

    public static List<String> splitWords(String s) {
        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        List<Integer> nums = splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList());
        scanner.close();
        long res = numberOfSwapsToSort(nums);
        System.out.println(res);
    }
}
```
```javascript
function numberOfSwapsToSort(nums) {
    // total number of inverted pairs, i.e. the sum of the counts array
    let count = 0;

    // Sorts [index, value] pairs by value
    function mergeSort(nums) {
        if (nums.length <= 1) return nums;
        const mid = Math.floor(nums.length / 2);
        const left = mergeSort(nums.slice(0, mid));
        const right = mergeSort(nums.slice(mid));
        return merge(left, right);
    }

    function merge(left, right) {
        const result = [];
        let l = 0, r = 0;
        while (l < left.length || r < right.length) {
            // take from the left on ties so equal values are not counted as smaller
            if (r >= right.length || (l < left.length && left[l][1] <= right[r][1])) {
                result.push(left[l]);
                count += r;
                l += 1;
            } else {
                result.push(right[r]);
                r += 1;
            }
        }
        return result;
    }

    // pair each value with its original index
    mergeSort(nums.map((e, i) => [i, e]));
    return count;
}

function splitWords(s) {
    return s == "" ? [] : s.split(' ');
}

function* main() {
    const nums = splitWords(yield).map((v) => parseInt(v, 10));
    const res = numberOfSwapsToSort(nums);
    console.log(res);
}

class EOFError extends Error {}
{
    const gen = main();
    const next = (line) => gen.next(line).done && process.exit();
    let buf = '';
    next();
    process.stdin.setEncoding('utf8');
    process.stdin.on('data', (data) => {
        const lines = (buf + data).split('\n');
        buf = lines.pop();
        lines.forEach(next);
    });
    process.stdin.on('end', () => {
        buf && next(buf);
        gen.throw(new EOFError());
    });
}
```
```cpp
#include <algorithm> // copy
#include <cstddef> // size_t
#include <iostream> // cin, cout
#include <iterator> // back_inserter, istream_iterator
#include <sstream> // istringstream
#include <string> // getline, string
#include <vector> // vector

// Each element is an {original index, value} pair; both halves are sorted by value.
// count is long long because the number of inversions can overflow int for large inputs.
std::vector<std::vector<int>> merge(const std::vector<std::vector<int>>& left, const std::vector<std::vector<int>>& right, long long& count) {
    std::vector<std::vector<int>> res;
    std::size_t l = 0, r = 0;
    while (l < left.size() || r < right.size()) {
        // take from the left on ties so equal values are not counted as smaller
        if (r >= right.size() || (l < left.size() && left[l][1] <= right[r][1])) {
            res.emplace_back(left[l]);
            count += static_cast<long long>(r);
            l++;
        } else {
            res.emplace_back(right[r]);
            r++;
        }
    }
    return res;
}

std::vector<std::vector<int>> merge_sort(const std::vector<std::vector<int>>& nums, long long& count) {
    if (nums.size() <= 1) return nums;
    int mid = static_cast<int>(nums.size() / 2);
    std::vector<std::vector<int>> split_left(nums.begin(), nums.begin() + mid);
    std::vector<std::vector<int>> split_right(nums.begin() + mid, nums.end());
    std::vector<std::vector<int>> left = merge_sort(split_left, count);
    std::vector<std::vector<int>> right = merge_sort(split_right, count);
    return merge(left, right, count);
}

long long number_of_swaps_to_sort(const std::vector<int>& nums) {
    long long count = 0;
    // pair each value with its original index
    std::vector<std::vector<int>> idx_num_mapping;
    for (int i = 0; i < static_cast<int>(nums.size()); i++) {
        idx_num_mapping.push_back({ i, nums[i] });
    }
    merge_sort(idx_num_mapping, count);
    return count;
}

template<typename T>
std::vector<T> get_words() {
    std::string line;
    std::getline(std::cin, line);
    std::istringstream ss{line};
    std::vector<T> v;
    std::copy(std::istream_iterator<T>{ss}, std::istream_iterator<T>{}, std::back_inserter(v));
    return v;
}

int main() {
    std::vector<int> nums = get_words<int>();
    long long res = number_of_swaps_to_sort(nums);
    std::cout << res << '\n';
}
```