973. K Closest Points to Origin

We're asked to return the k points closest to the origin (0, 0). A heap would work, but quickselect
solves the problem in O(n) time on average. As in previous quickselect questions, we first define the
value we compare on: here it is the Euclidean distance from a point to the origin, computed with
math.dist and wrapped in a lambda.
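
For comparison, the heap alternative mentioned above could be sketched as follows. This is only a
baseline under stated assumptions: the function name k_closest_heap is hypothetical, and
heapq.nsmallest is used here in place of a hand-written heap. It uses the same distance key as the
quickselect solution and runs in O(n log k) time.

```python
import heapq
from math import dist


def k_closest_heap(points, k):
    # Hypothetical helper: key each point by its Euclidean distance to the origin
    # and keep the k smallest.
    return heapq.nsmallest(k, points, key=lambda p: dist(p, (0, 0)))


points = [[1, 3], [-2, 2], [3, 4]]
print(k_closest_heap(points, 1))  # [[-2, 2]]
```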

During quickselect, we pick a random index in the current range as the pivot and swap the pivot with
the rightmost point. We then scan from the leftmost point: whenever a point's distance is greater than
the pivot's distance, we swap it with the point at store_index and increment store_index. After the
scan, every point farther than the pivot sits to the left of store_index. Finally, we swap the pivot
into store_index and return store_index.
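
To make the partition step concrete, here is a small standalone sketch of it. It assumes a
free-standing partition function that takes the list as a parameter (the solution below uses a nested
function instead), and it fixes the pivot rather than choosing it at random so the trace is
reproducible.

```python
from math import dist


def partition(points, left, right, pivot_index):
    origin = (0, 0)
    pivot_distance = dist(points[pivot_index], origin)
    # Park the pivot at the right end while scanning.
    points[pivot_index], points[right] = points[right], points[pivot_index]

    store_index = left
    for i in range(left, right):
        # Points farther than the pivot are moved to the left part.
        if dist(points[i], origin) > pivot_distance:
            points[store_index], points[i] = points[i], points[store_index]
            store_index += 1

    # Put the pivot in its final position.
    points[store_index], points[right] = points[right], points[store_index]
    return store_index


points = [[1, 0], [4, 0], [2, 0], [3, 0]]
p = partition(points, 0, 3, pivot_index=2)  # the pivot is [2, 0]
print(p, points)  # 2 [[4, 0], [3, 0], [2, 0], [1, 0]]
```

The two points farther than the pivot end up at indices 0 and 1, the pivot lands at index 2, and the
closer point stays to its right.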

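Because partitioning moves farther points to the left, the k closest points end up in the last k
positions, so the index we search for is n - k. If the pivot lands on that index, the pivot and every
point to its right are the k closest points. If the target index is less than the pivot index, we
quickselect the left side of the pivot, else we quickselect the right side of the pivot.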
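
The stopping condition can be checked with a small sketch. The helper name pivot_splits_k_closest is
hypothetical, and the sample list is simply arranged the way a finished quickselect might leave it
(farthest points first); it is not output from the solution.

```python
from math import dist


def pivot_splits_k_closest(points, pivot_index):
    # True when no point left of pivot_index is closer to the origin
    # than any point from pivot_index onward.
    origin = (0, 0)
    left = [dist(p, origin) for p in points[:pivot_index]]
    right = [dist(p, origin) for p in points[pivot_index:]]
    return not left or not right or min(left) >= max(right)


points = [[5, 5], [3, 4], [1, 1], [0, 1]]
n, k = len(points), 2
target = n - k  # index where the k closest points begin
print(pivot_splits_k_closest(points, target))  # True
print(points[target:])  # [[1, 1], [0, 1]]
```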

The solution is as follows:


  from math import dist
  from random import randint
  from typing import List

  class Solution:
      def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
          distance = lambda x: dist((points[x][0], points[x][1]), (0, 0))

          def swap(i: int, j: int) -> None:
              points[i], points[j] = points[j], points[i]

          def partition(left: int, right: int, pivot_index: int) -> int:
              pivot_distance = distance(pivot_index)
              swap(pivot_index, right)

              store_index = left
              # Move every point farther than the pivot into the left part.
              for i in range(left, right):
                  if distance(i) > pivot_distance:
                      swap(store_index, i)
                      store_index += 1

              swap(store_index, right)

              return store_index

          def quickselect(left: int, right: int, k_smallest: int) -> None:
              if left == right:
                  return

              pivot_index = randint(left, right)
              pivot_index = partition(left, right, pivot_index)

              if k_smallest == pivot_index:
                  return
              elif k_smallest < pivot_index:
                  quickselect(left, pivot_index - 1, k_smallest)
              else:
                  quickselect(pivot_index + 1, right, k_smallest)

          # Farther points are moved left, so the k closest end up in the last k slots.
          n = len(points)
          quickselect(0, n - 1, n - k)
          return points[n - k :]


_ Time Complexity:

  O(n) on average - each partition narrows the search to one side of the pivot. The worst case is
  O(n^2) if the random pivots are consistently poor.

_ Space Complexity:

  O(n) - The recursion depth is O(n) in the worst case and O(log n) on average. The partition itself
  works in place.