973. K Closest Points to Origin


Problem Description

The problem provides us with an array of points called points, where each point is itself an array containing an x (x_i) and y (y_i) coordinate. These points are located on a two-dimensional plane. Additionally, we're given an integer k. Our task is to find the k points from the array that are closest to the origin point (0, 0) based on the Euclidean distance. The Euclidean distance between two points (x1, y1) and (x2, y2) on the X-Y plane is the square root of (x1-x2)^2 + (y1-y2)^2. The solution should return these k closest points in any order, and it is noted that the points returned will be unique in terms of their position in the output list.

Intuition

To solve this problem, we need to determine how far each point in the array is from the origin. In geometrical terms, we are interested in the magnitude of the vector from the origin to the point, which corresponds to the Euclidean distance. However, since we only care about the relative distances for sorting (finding the smallest distances) and not the exact distances themselves, we can avoid calculating the square root to simplify the comparison. By comparing the squares of the distances, we can maintain the same ordering as we would if we used the actual distances.

Here is the thought process behind the solution:

  1. We recognize that we need to sort the points by their distance to the origin.
  2. To avoid unnecessary computation, we can compare the squared distances instead of the actual distances.
  3. We define a custom sorting key that calculates x^2 + y^2 for each point, where x and y are the coordinates of that point.
  4. Using the sort() function of a list in Python, we sort all points based on their squared distances from the origin.
  5. Finally, we return the first k elements of the sorted array, as these will be the k closest points to the origin.

By applying this approach, we achieve a manageable solution that is easy to implement and understand, using built-in Python functionalities for sorting with a custom key.

Learn more about Math, Divide and Conquer, Sorting and Heap (Priority Queue) patterns.

Not Sure What to Study? Take the 2-min Quiz to Find Your Missing Piece:

Which of the following problems can be solved with backtracking (select multiple)

Solution Approach

The implementation of the solution makes use of Python's built-in sort function and the concept of lambda functions for the custom sorting key.

  1. The lambda function is employed as the sorting key for the sort() method. This function takes a point p as input and returns p[0] * p[0] + p[1] * p[1]. This expression equates to the square of the distance from the point to the origin, omitting the square root for the reasons described earlier.

  2. The algorithm:

    • Accepts the list of points and the integer k.
    • Applies the sort() function on the list of points. Instead of sorting by a single element, it sorts by the value returned by the lambda function for each element, which represents the squared distance of each point to the origin.
    • The sorting process rearranges the items of the list in place, meaning that after the sort() method is called, the original list is modified to be in sorted order.
  3. After sorting the full list, a slice of the first k elements (points[:k]) is returned. In Python, slicing a list returns a new list containing the specified range of elements, so in this case, it gives us the closest k points.

  4. The use of the sort() function and a slicing operation makes the solution concise and effective. The Python sort() method is an efficient, typically O(n log n) algorithm for sorting lists.

  5. The space complexity is kept to O(1) since the sorting is done in place and only the resulting list slice of size k is returned.

In conclusion, the solution leverages the efficiency of Python's sorting algorithm and a concise expression of the distance calculation to produce a clean and efficient algorithm for finding the k closest points to the origin.

Discover Your Strengths and Weaknesses: Take Our 2-Minute Quiz to Tailor Your Study Plan:

What's the output of running the following function using input 56?

1KEYBOARD = {
2    '2': 'abc',
3    '3': 'def',
4    '4': 'ghi',
5    '5': 'jkl',
6    '6': 'mno',
7    '7': 'pqrs',
8    '8': 'tuv',
9    '9': 'wxyz',
10}
11
12def letter_combinations_of_phone_number(digits):
13    def dfs(path, res):
14        if len(path) == len(digits):
15            res.append(''.join(path))
16            return
17
18        next_number = digits[len(path)]
19        for letter in KEYBOARD[next_number]:
20            path.append(letter)
21            dfs(path, res)
22            path.pop()
23
24    res = []
25    dfs([], res)
26    return res
27
1private static final Map<Character, char[]> KEYBOARD = Map.of(
2    '2', "abc".toCharArray(),
3    '3', "def".toCharArray(),
4    '4', "ghi".toCharArray(),
5    '5', "jkl".toCharArray(),
6    '6', "mno".toCharArray(),
7    '7', "pqrs".toCharArray(),
8    '8', "tuv".toCharArray(),
9    '9', "wxyz".toCharArray()
10);
11
12public static List<String> letterCombinationsOfPhoneNumber(String digits) {
13    List<String> res = new ArrayList<>();
14    dfs(new StringBuilder(), res, digits.toCharArray());
15    return res;
16}
17
18private static void dfs(StringBuilder path, List<String> res, char[] digits) {
19    if (path.length() == digits.length) {
20        res.add(path.toString());
21        return;
22    }
23    char next_digit = digits[path.length()];
24    for (char letter : KEYBOARD.get(next_digit)) {
25        path.append(letter);
26        dfs(path, res, digits);
27        path.deleteCharAt(path.length() - 1);
28    }
29}
30
1const KEYBOARD = {
2    '2': 'abc',
3    '3': 'def',
4    '4': 'ghi',
5    '5': 'jkl',
6    '6': 'mno',
7    '7': 'pqrs',
8    '8': 'tuv',
9    '9': 'wxyz',
10}
11
12function letter_combinations_of_phone_number(digits) {
13    let res = [];
14    dfs(digits, [], res);
15    return res;
16}
17
18function dfs(digits, path, res) {
19    if (path.length === digits.length) {
20        res.push(path.join(''));
21        return;
22    }
23    let next_number = digits.charAt(path.length);
24    for (let letter of KEYBOARD[next_number]) {
25        path.push(letter);
26        dfs(digits, path, res);
27        path.pop();
28    }
29}
30

Example Walkthrough

Let's illustrate the solution approach with a simple example. Imagine we have the following points and we are tasked with finding the k = 2 closest points to the origin (0, 0):

1points = [[1, 3], [2, -2], [5, 4], [-3, 3]]

Following the thought process behind the solution:

  1. We calculate the squared distance from the origin for each point which gives us:

    • For point [1, 3]: 1^2 + 3^2 = 1 + 9 = 10
    • For point [2, -2]: 2^2 + (-2)^2 = 4 + 4 = 8
    • For point [5, 4]: 5^2 + 4^2 = 25 + 16 = 41
    • For point [-3, 3]: (-3)^2 + 3^2 = 9 + 9 = 18
  2. We sort the points based on their squared distances, without computing the square root:

    1Sorted points based on squared distances: [[2, -2], [1, 3], [-3, 3], [5, 4]]
    2Sorted squared distances: [8, 10, 18, 41]
  3. We employ a lambda function in Python to serve as a custom sorting key: lambda p: p[0]*p[0] + p[1]*p[1].

  4. We apply this lambda function within the sort() method to our points list to rearrange the list based on the calculated squared distances:

    1points.sort(key=lambda p: p[0]*p[0] + p[1]*p[1])
  5. After sorting, points looks like this: [[2, -2], [1, 3], [-3, 3], [5, 4]].

  6. We return the first k elements of the sorted array to get the k closest points. Since k = 2, we return the first two points:

    1closest_points = points[:k]  # closest_points = [[2, -2], [1, 3]]

With this example, we have effectively walked through the solution approach described in the problem content. We sorted the points by their squared distance from the origin, avoided unnecessary square root computation, and returned the k closest points by slicing the sorted array without altering the original array's order.

Solution Implementation

1from typing import List
2
3class Solution:
4    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
5        # Sort the given list of points by their Euclidean distance
6        # from the origin (0, 0) without actually computing the square root.
7        # The lambda function computes the squared distance.
8        points.sort(key=lambda point: point[0]**2 + point[1]**2)
9      
10        # Return the first k points from the sorted list.
11        return points[:k]
12
1import java.util.Arrays;
2
3class Solution {
4    // This method finds the k closest points to the origin
5    public int[][] kClosest(int[][] points, int k) {
6        // Sort the array of points based on their Euclidean distance from the origin
7        Arrays.sort(points, (point1, point2) -> {
8            // Calculate the squared distance for the first point from the origin
9            int distance1 = point1[0] * point1[0] + point1[1] * point1[1];
10            // Calculate the squared distance for the second point from the origin
11            int distance2 = point2[0] * point2[0] + point2[1] * point2[1];
12            // Compare the two distances
13            return distance1 - distance2;
14        });
15
16        // Return the first k elements of the sorted array, which are the k closest to the origin
17        return Arrays.copyOfRange(points, 0, k);
18    }
19}
20
1#include <vector>
2#include <algorithm>
3
4class Solution {
5public:
6    // Method to find the 'k' closest points to the origin (0, 0) in a 2D plane
7    std::vector<std::vector<int>> kClosest(std::vector<std::vector<int>>& points, int k) {
8        // Sort the 'points' array using a custom comparator
9        std::sort(points.begin(), points.end(), [](const std::vector<int>& pointA, const std::vector<int>& pointB) {
10            // Calculate the squared Euclidean distance from the origin for pointA
11            int distA = pointA[0] * pointA[0] + pointA[1] * pointA[1];
12            // Calculate the squared Euclidean distance from the origin for pointB
13            int distB = pointB[0] * pointB[0] + pointB[1] * pointB[1];
14            // Compare the squared distances to sort the points by distance
15            return distA < distB;
16        });
17
18        // Construct a vector containing the first 'k' elements from the sorted 'points' array
19        return std::vector<std::vector<int>>(points.begin(), points.begin() + k);
20    }
21};
22
1function kClosest(points: number[][], k: number): number[][] {
2    // A comparison function for the .sort() method, which will sort the points
3    // based on their distance squared from the origin in ascending order.
4    const compareDistance = (pointA: number[], pointB: number[]) => {
5        // Calculate the distance squared of pointA from the origin.
6        const distanceA = pointA[0] ** 2 + pointA[1] ** 2;
7        // Calculate the distance squared of pointB from the origin.
8        const distanceB = pointB[0] ** 2 + pointB[1] ** 2;
9        // The return value determines the order of sorting.
10        return distanceA - distanceB;
11    };
12
13    // Sort the array of points with the custom compare function.
14    const sortedPoints = points.sort(compareDistance);
15
16    // Return the first 'k' elements from the sorted array of points.
17    return sortedPoints.slice(0, k);
18}
19
Not Sure What to Study? Take the 2-min Quiz:

In a binary min heap, the maximum element can be found in:

Time and Space Complexity

Time Complexity: The time complexity of the code is O(n log n), where n is the number of points. This arises from the use of the .sort() method, which has O(n log n) complexity for sorting the list of points based on their distance from the origin calculated by the key function (lambda p: p[0] * p[0] + p[1] * p[1]).

Space Complexity: The space complexity of the code is O(n). While the sorting is done in-place and does not require additional space proportional to the input, the sorted list of points is stored in the same space that was taken by the input list. Hence, the space complexity is linear with respect to the size of the input list of points.

Learn more about how to find time and space complexity quickly using problem constraints.

Fast Track Your Learning with Our Quick Skills Quiz:

Which of the following is equvalent to O(3*2^n + n^3 + n!+ log n)?


Recommended Readings


Got a question? Ask the Teaching Assistant anything you don't understand.

Still not clear? Ask in the Forum,  Discord or Submit the part you don't understand to our editors.


TA 👨‍🏫