347. Top K Frequent Elements

We're asked to solve this problem faster than O(n log n), where n is the length of the input. That's
completely doable, in two ways. One way is with a heap. We use the Counter class from the collections
module to count the frequency of each number in the input. Then we use the heapq module's nlargest
function to get the k most frequent elements. This approach has a time complexity of O(n log k),
which is faster than O(n log n) when k is smaller than n.
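
The heap approach described above can be sketched in a few lines. This is a minimal illustration, not
the solution used later in this post; the function name top_k_frequent_heap is an assumption made for
the example.

```python
import heapq
from collections import Counter
from typing import List


def top_k_frequent_heap(nums: List[int], k: int) -> List[int]:
    # Count how often each number appears in the input.
    count = Counter(nums)
    # Pick the k keys with the largest counts; nlargest keeps a heap of size k.
    return heapq.nlargest(k, count.keys(), key=count.get)
```

For example, top_k_frequent_heap([1, 1, 1, 2, 2, 3], 2) returns [1, 2]. The problem allows the answer
in any order, so callers should not rely on the order of the returned list.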

Another method to solve this is using Quickselect, which has an average time complexity of O(n). The
idea here is that we record the frequency of all the numbers, and we maintain an array of the unique
numbers. We then partition the array based on the frequency of the numbers. This partition algorithm
randomly selects a pivot and partitions the array into two parts: elements that are less frequent
than the pivot go to its left, and elements that are at least as frequent go to its right.
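
To make the partition step concrete, here is a small standalone sketch of it run on a fixed input.
The helper name partition_by_frequency is made up for this example, and the pivot index is passed in
by hand (rather than chosen randomly) so that the result is reproducible.

```python
from collections import Counter


def partition_by_frequency(unique, count, left, right, pivot_index):
    # Lomuto-style partition of unique[left..right], ordered by frequency.
    pivot_frequency = count[unique[pivot_index]]
    # Move the pivot out of the way, to the end of the range.
    unique[pivot_index], unique[right] = unique[right], unique[pivot_index]

    store_index = left
    for i in range(left, right):
        # Less frequent elements are collected at the front of the range.
        if count[unique[i]] < pivot_frequency:
            unique[store_index], unique[i] = unique[i], unique[store_index]
            store_index += 1

    # Place the pivot between the two groups and report where it landed.
    unique[right], unique[store_index] = unique[store_index], unique[right]
    return store_index


nums = [1, 1, 1, 2, 2, 3, 4, 4, 4, 4]  # frequencies: 1 -> 3, 2 -> 2, 3 -> 1, 4 -> 4
count = Counter(nums)
unique = list(count.keys())            # [1, 2, 3, 4]

# Use the element at index 0 (the number 1, frequency 3) as the pivot.
p = partition_by_frequency(unique, count, 0, len(unique) - 1, 0)
# p is 2 and unique is now [2, 3, 1, 4]
```

Everything to the left of index p (2 and 3) is less frequent than the pivot, and everything to its
right (4) is at least as frequent.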

Here n is the number of unique elements. Once the partition is complete, if the pivot's index is
equal to n - k, the pivot and all the elements to its right are the k most frequent elements. If the
pivot's index is less than n - k, we recurse on the right side of the pivot. If the pivot's index is
greater than n - k, we recurse on the left side of the pivot.

The solution is as follows:


import random
from collections import Counter
from typing import List

class Solution:
    def topKFrequent(self, nums: List[int], k: int) -> List[int]:
        # Frequency of each number, and the distinct numbers we reorder in place.
        count = Counter(nums)
        unique = list(count.keys())

        def partition(left, right, pivot_index) -> int:
            pivot_frequency = count[unique[pivot_index]]
            # Park the pivot at the end while the rest of the range is scanned.
            unique[pivot_index], unique[right] = unique[right], unique[pivot_index]

            # Move every less frequent element to the front of the range.
            store_index = left
            for i in range(left, right):
                if count[unique[i]] < pivot_frequency:
                    unique[store_index], unique[i] = unique[i], unique[store_index]
                    store_index += 1

            # Put the pivot into its final position.
            unique[right], unique[store_index] = unique[store_index], unique[right]

            return store_index

        def quickselect(left, right, k_smallest) -> None:
            # A single-element range is already in place.
            if left == right:
                return

            pivot_index = random.randint(left, right)
            pivot_index = partition(left, right, pivot_index)

            if k_smallest == pivot_index:
                return
            elif k_smallest < pivot_index:
                quickselect(left, pivot_index - 1, k_smallest)
            else:
                quickselect(pivot_index + 1, right, k_smallest)

        n = len(unique)
        # After this call, indices n - k onward hold the k most frequent elements.
        quickselect(0, n - 1, n - k)

        return unique[n - k:]


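- Time Complexity:

  O(n) on average - We count the frequency of each number in the input in O(n) time, and Quickselect
  over the unique numbers takes O(n) time on average. In the worst case, when the random pivots are
  chosen badly, Quickselect degrades to O(n^2).

- Space Complexity:

  O(n) - We store the frequency of each number in the input, along with the list of unique numbers.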