user1692261

Reputation: 1217

Top K Frequent Elements - time complexity: Bucket Sort vs Heap

I was working on a leetcode problem (https://leetcode.com/problems/top-k-frequent-elements/) which is:

Given an integer array nums and an integer k, return the k most frequent elements. You may return the answer in any order.

I solved this using a min-heap (my time-complexity calculations are in the comments - do correct me if I made a mistake):

    from collections import Counter
    from heapq import heapify, heappushpop
    from itertools import islice

    def top_k_frequent(nums, k):
        if k == len(nums):
            return nums

        # O(N)
        c = Counter(nums)

        it = iter([(x[1], x[0]) for x in c.items()])

        # O(K)
        result = list(islice(it, k))
        heapify(result)

        # O(N-K)
        for elem in it:
            # O(log K)
            heappushpop(result, elem)

        # O(K)
        return [pair[1] for pair in result]
    
    # O(K) + O(N) + O((N - K) log K) + O(K log K)
    # if k < N :
    #   O(N log K)

Then I saw a solution using Bucket Sort that is supposed to beat the heap solution with O(N):

    import collections
    import itertools

    def top_k_frequent_bucket(nums, k):
        bucket = [[] for _ in nums]

        # O(N)
        c = collections.Counter(nums)

        # O(d) where d is the number of distinct numbers. d <= N
        for num, freq in c.items():
            bucket[-freq].append(num)

        # O(?)
        return list(itertools.chain(*bucket))[:k]

How do we compute the time complexity of the itertools.chain call here? Does it come from the fact that we chain at most N elements in total? Is that enough to deduce that it is O(N)?
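One way to sanity-check this empirically (the sizes below are arbitrary, just for illustration) is to time the call for growing inputs and see whether it scales linearly:

    import timeit
    from itertools import chain

    # N slots holding N elements in total, as in the bucket solution;
    # doubling n should roughly double the time if the call is O(N).
    for n in (10_000, 20_000, 40_000):
        bucket = [[i] for i in range(n)]
        t = timeit.timeit(lambda: list(chain(*bucket)), number=100)
        print(n, t)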

In any case, at least on the LeetCode test cases, the first one performs better - what could be the reason for that?

Upvotes: 0

Views: 1372

Answers (2)

Booboo

Reputation: 44108

Notwithstanding my comment, using nlargest seems to run more slowly than using heapify, etc. (see below). But the Bucket Sort, at least for this input, is definitely more performant. It also seems that, with the Bucket Sort, creating the full list of N elements just to take the first k does not incur much of a penalty.

from collections import Counter
from heapq import heapify, heappushpop, nlargest
from itertools import chain, islice

def most_frequent_1a(nums, k):
    if k == len(nums):
        return nums

    # O(N)
    c = Counter(nums)

    it = iter([(x[1], x[0]) for x in c.items()])

    # O(K)
    result = list(islice(it, k))
    heapify(result)

    # O(N-K)
    for elem in it:
        # O(log K)
        heappushpop(result, elem)

    # O(K)
    return [pair[1] for pair in result]

def most_frequent_1b(nums, k):        
    if k == len(nums):
        return nums

    c = Counter(nums)        
    return [pair[1] for pair in nlargest(k, [(x[1], x[0]) for x in c.items()])]


def most_frequent_2a(nums, k):
    bucket = [[] for _ in nums]

    # O(N)
    c = Counter(nums)

    # O(d) where d is the number of distinct numbers. d <= N
    for num, freq in c.items():
        bucket[-freq].append(num)

    # O(?)
    return list(chain(*bucket))[:k]


def most_frequent_2b(nums, k):
    bucket = [[] for _ in nums]

    # O(N)
    c = Counter(nums)

    # O(d) where d is the number of distinct numbers. d <= N
    for num, freq in c.items():
        bucket[-freq].append(num)

    # O(?)
    # don't create full list:
    i = 0
    for elem in chain(*bucket):
        yield elem
        i += 1
        if i == k:
            break

import timeit
nums = [i for i in range(1000)]
nums.append(7)
nums.append(88)
nums.append(723)
print(most_frequent_1a(nums, 3))
print(most_frequent_1b(nums, 3))
print(most_frequent_2a(nums, 3))
print(list(most_frequent_2b(nums, 3)))
print(timeit.timeit(stmt='most_frequent_1a(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='most_frequent_1b(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='most_frequent_2a(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='list(most_frequent_2b(nums, 3))', number=10000, globals=globals()))

Prints:

[7, 723, 88]
[723, 88, 7]
[7, 88, 723]
[7, 88, 723]
3.180169899998873
4.487235299999156
2.710413699998753
2.62860400000136

Upvotes: 0

kaya3

Reputation: 51034

The time complexity of list(itertools.chain(*bucket)) is O(N) where N is the total number of elements in the nested list bucket. The chain function is roughly equivalent to this:

def chain(*iterables):
    for iterable in iterables:
        for item in iterable:
            yield item

The yield statement dominates the running time, is O(1), and executes N times, hence the result.
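A quick way to see the "executes N times" part concretely (the bucket below is made up for illustration):

    from itertools import chain

    # 5 inner lists holding 15 elements in total: the inner loop body
    # (the yield) runs once per element, regardless of how the elements
    # are distributed across the inner lists.
    bucket = [[0] * 3 for _ in range(5)]
    print(sum(1 for _ in chain(*bucket)))  # 15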


The reason your O(N log k) algorithm might end up being faster in practice is that log k is probably not very large; LeetCode says k is at most the number of distinct elements in the array, but I suspect that for most of the test cases k is much smaller, and of course log k is smaller still (even for k = 1,000, log2 k is only about 10).

The O(N) algorithm probably has a relatively high constant factor because it allocates N lists and then randomly accesses them by index, resulting in a lot of cache misses; the append operations may also cause a lot of those lists to be reallocated as they become larger.
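As an aside, here is a sketch (the function name and structure are mine, not from either solution above) of a bucket variant that only allocates lists for frequencies that actually occur, trading the N empty-list allocations for dict operations while staying O(N):

    from collections import Counter, defaultdict

    def most_frequent_2c(nums, k):
        # Group numbers by frequency; only frequencies that occur get a list.
        c = Counter(nums)
        by_freq = defaultdict(list)
        for num, freq in c.items():
            by_freq[freq].append(num)
        result = []
        # Scan candidate frequencies from high to low: at most N loop
        # iterations, but no per-slot list allocations.
        for freq in range(len(nums), 0, -1):
            if freq in by_freq:
                result.extend(by_freq[freq])
                if len(result) >= k:
                    break
        return result[:k]

Whether this actually wins in practice would depend on the data; it is only meant to illustrate where the constant factor comes from.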

Upvotes: 0
