Reputation: 1217
I was working on a leetcode problem (https://leetcode.com/problems/top-k-frequent-elements/) which is:
Given an integer array nums and an integer k, return the k most frequent elements. You may return the answer in any order.
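For example (my own illustration of the expected behaviour, not taken from the site's test cases):

nums = [1, 1, 1, 2, 2, 3]
k = 2
# 1 occurs three times and 2 occurs twice, so [1, 2] (in any order) is a valid answer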
I solved this using a min-heap (my time complexity calculations are in the comments - do correct me if I made a mistake):
from collections import Counter
from heapq import heapify, heappushpop
from itertools import islice

def topKFrequent(nums, k):
    if k == len(nums):
        return nums
    # O(N)
    c = Counter(nums)
    it = iter([(x[1], x[0]) for x in c.items()])
    # O(K)
    result = list(islice(it, k))
    heapify(result)
    # O(N-K)
    for elem in it:
        # O(log K)
        heappushpop(result, elem)
    # O(K)
    return [pair[1] for pair in result]
# O(K) + O(N) + O((N - K) log K) + O(K log K)
# if k < N :
# O(N log K)
Then I saw a solution using Bucket Sort that is supposed to beat the heap solution with O(N):
import collections
import itertools

def topKFrequent_bucket(nums, k):
    bucket = [[] for _ in nums]
    # O(N)
    c = collections.Counter(nums)
    # O(d) where d is the number of distinct numbers. d <= N
    for num, freq in c.items():
        bucket[-freq].append(num)
    # O(?)
    return list(itertools.chain(*bucket))[:k]
How do we compute the time complexity of the itertools.chain call here? Does it come from the fact that at most we will chain N elements? Is that enough to deduce it is O(N)?
In any case, at least on the LeetCode test cases, the first one has better performance - what can be the reason for that?
Upvotes: 0
Views: 1372
Reputation: 44108
Notwithstanding my comment, using nlargest seems to run more slowly than using heapify, etc. (see below). But the Bucket Sort, at least for this input, is definitely more performant. It would also seem that with the Bucket Sort, creating the full list of N elements just to take the first k does not cause too much of a penalty.
from collections import Counter
from heapq import heapify, heappushpop, nlargest
from itertools import chain, islice

def most_frequent_1a(nums, k):
    if k == len(nums):
        return nums
    # O(N)
    c = Counter(nums)
    it = iter([(x[1], x[0]) for x in c.items()])
    # O(K)
    result = list(islice(it, k))
    heapify(result)
    # O(N-K)
    for elem in it:
        # O(log K)
        heappushpop(result, elem)
    # O(K)
    return [pair[1] for pair in result]
def most_frequent_1b(nums, k):
    if k == len(nums):
        return nums
    c = Counter(nums)
    return [pair[1] for pair in nlargest(k, [(x[1], x[0]) for x in c.items()])]
def most_frequent_2a(nums, k):
    bucket = [[] for _ in nums]
    # O(N)
    c = Counter(nums)
    # O(d) where d is the number of distinct numbers. d <= N
    for num, freq in c.items():
        bucket[-freq].append(num)
    # O(?)
    return list(chain(*bucket))[:k]
def most_frequent_2b(nums, k):
    bucket = [[] for _ in nums]
    # O(N)
    c = Counter(nums)
    # O(d) where d is the number of distinct numbers. d <= N
    for num, freq in c.items():
        bucket[-freq].append(num)
    # O(?)
    # don't create the full list:
    i = 0
    for elem in chain(*bucket):
        yield elem
        i += 1
        if i == k:
            break
import timeit
nums = [i for i in range(1000)]
nums.append(7)
nums.append(88)
nums.append(723)
print(most_frequent_1a(nums, 3))
print(most_frequent_1b(nums, 3))
print(most_frequent_2a(nums, 3))
print(list(most_frequent_2b(nums, 3)))
print(timeit.timeit(stmt='most_frequent_1a(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='most_frequent_1b(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='most_frequent_2a(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='list(most_frequent_2b(nums, 3))', number=10000, globals=globals()))
Prints:
[7, 723, 88]
[723, 88, 7]
[7, 88, 723]
[7, 88, 723]
3.180169899998873
4.487235299999156
2.710413699998753
2.62860400000136
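As a side note, the manual counter in most_frequent_2b could equally be written with itertools.islice, which also stops after k elements without building the full list. A minimal sketch of that variant (my own, not part of the benchmark above):

from collections import Counter
from itertools import chain, islice

def most_frequent_2c(nums, k):
    # same bucket construction as most_frequent_2a/2b
    bucket = [[] for _ in nums]
    c = Counter(nums)
    for num, freq in c.items():
        bucket[-freq].append(num)
    # islice stops the iteration after k elements, so the full list is never built
    return list(islice(chain(*bucket), k))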
Upvotes: 0
Reputation: 51034
The time complexity of list(itertools.chain(*bucket)) is O(N), where N is the total number of elements in the nested list bucket. The chain function is roughly equivalent to this:
def chain(*iterables):
    for iterable in iterables:
        for item in iterable:
            yield item
The yield statement dominates the running time, is O(1), and executes N times, hence the result.
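A quick illustration with a toy bucket list (my own example, not the question's data): every stored element is yielded exactly once, no matter how many empty sublists there are.

from itertools import chain

bucket = [[3], [], [1, 2], [], []]
print(list(chain(*bucket)))  # [3, 1, 2] - one yield per stored element, so O(N) overall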
The reason your O(N log k) algorithm might end up being faster in practice is that log k is probably not very large; LeetCode says k is at most the number of distinct elements in the array, but I suspect that for most of the test cases k is much smaller, and of course log k is smaller still.
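To put a rough number on that (my own back-of-the-envelope figures, not from the benchmark above):

import math

for k in (3, 10, 100):
    # log2(k) is the factor separating O(N log k) from O(N)
    print(k, round(math.log2(k), 2))
# prints: 3 1.58, 10 3.32, 100 6.64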
The O(N) algorithm probably has a relatively high constant factor because it allocates N lists and then randomly accesses them by index, resulting in a lot of cache misses; the append operations may also cause a lot of those lists to be reallocated as they become larger.
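For what it's worth, one way to avoid allocating N mostly-empty lists is to group the numbers by frequency in a dictionary and walk the observed frequencies in decreasing order. A minimal sketch of that idea (my own variant, not code from the question; it pays for an extra sort of the distinct frequencies):

from collections import Counter, defaultdict

def top_k_frequent_by_dict(nums, k):
    c = Counter(nums)                    # O(N)
    by_freq = defaultdict(list)          # frequency -> numbers with that frequency
    for num, freq in c.items():
        by_freq[freq].append(num)
    result = []
    for freq in sorted(by_freq, reverse=True):  # sorts only the distinct frequencies
        for num in by_freq[freq]:
            result.append(num)
            if len(result) == k:
                return result
    return result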
Upvotes: 0