Reputation: 39
Could someone please explain why the order of the output changes when the nlargest function is called with a key function that uses only the first element of each tuple?
import heapq
heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]
heapq.nlargest(2, heap_arr)
# Perfectly fine: the output is [(3, 'd'), (3, 'c')]
# This is similar to heapq.nlargest(2, heap_arr, key=lambda a: (a[0], a[1]))
heapq.nlargest(2, heap_arr, key=lambda a: a[0])
# The output is [(3, 'c'), (3, 'd')]... Why??
Why does (3, 'c') appear before (3, 'd') in the second example? I ask because the order of the tuples in the output list is important.
Upvotes: 0
Views: 1300
Reputation: 879471
Short Answer:
heapq.nlargest(2, heap_arr) returns [(3, 'd'), (3, 'c')] because
In [6]: (3, 'd') > (3, 'c')
Out[6]: True
heapq.nlargest(2, heap_arr, key=lambda a: a[0]) returns [(3, 'c'), (3, 'd')] because heapq.nlargest, like sorted, is stable. Since the keys match (both are 3), the items are returned in the order in which they appear in heap_arr:
In [8]: heapq.nlargest(2, [(3, 'c'), (3, 'd')], key=lambda a: a[0])
Out[8]: [(3, 'c'), (3, 'd')]
In [9]: heapq.nlargest(2, [(3, 'd'), (3, 'c')], key=lambda a: a[0])
Out[9]: [(3, 'd'), (3, 'c')]
Longer Answer:
Per the docs, heapq.nlargest(n, iterable, key) is equivalent to sorted(iterable, key=key, reverse=True)[:n] (though heapq.nlargest computes its result in a different way). Nevertheless, we can use this equivalency to check that heapq.nlargest is behaving as we should expect:
import heapq
heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]
assert heapq.nlargest(2, heap_arr) == sorted(heap_arr, reverse=True)[:2]
assert heapq.nlargest(2, heap_arr, key=lambda a: a[0]) == sorted(heap_arr, key=lambda a: a[0], reverse=True)[:2]
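The two asserts above check a single input. A broader version of the same check, sketched here using only the standard library, compares heapq.nlargest with sorted(..., reverse=True)[:n] over many small random lists that are full of ties, for every n from 0 up past the list length. The helper name check_equivalence and the random-data parameters are illustrative choices, not part of heapq.

```python
import heapq
import random
from operator import itemgetter


def check_equivalence(trials=1000, seed=0):
    """Compare heapq.nlargest with its documented sorted() equivalent.

    Returns True when every comparison matches; otherwise an
    AssertionError is raised at the first mismatch.
    """
    rng = random.Random(seed)  # fixed seed so runs are reproducible
    first = itemgetter(0)      # key on the first field only, like lambda a: a[0]
    for _ in range(trials):
        size = rng.randint(0, 8)
        # Small value ranges force plenty of ties on both tuple fields.
        data = [(rng.randint(0, 3), rng.choice("abcd")) for _ in range(size)]
        for n in range(size + 2):
            # No key: whole tuples are compared.
            assert heapq.nlargest(n, data) == sorted(data, reverse=True)[:n]
            # Key on the first field: equal keys must keep input order.
            assert (heapq.nlargest(n, data, key=first)
                    == sorted(data, key=first, reverse=True)[:n])
    return True


print(check_equivalence())  # True
```

Including n values of 0, 1 and len(data) or more also exercises the max() and sorted() shortcuts that nlargest takes internally.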
So if you accept this equivalency, then you merely have to confirm the validity of
In [47]: sorted(heap_arr, reverse=True)
Out[47]: [(3, 'd'), (3, 'c'), (2, 'b'), (2, 'b'), (1, 'a')]
In [48]: sorted(heap_arr, key=lambda a: a[0], reverse=True)
Out[48]: [(3, 'c'), (3, 'd'), (2, 'b'), (2, 'b'), (1, 'a')]
When using key=lambda a: a[0], (3, 'c') and (3, 'd') are sorted according to the same key value, 3. Because Python's sort is stable (even with reverse=True), two items with equal keys (e.g. (3, 'c') and (3, 'd')) appear in the result in the same order as they appear in heap_arr.
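One detail worth making explicit is that sorting with reverse=True is not the same as sorting ascending and then reversing the list: the latter also reverses the ties. A small comparison on heap_arr (first_field is just a named stand-in for lambda a: a[0]):

```python
import heapq

heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]


def first_field(a):
    """Key function equivalent to lambda a: a[0]."""
    return a[0]


# reverse=True sorts in descending order but keeps equal keys in input order.
desc_stable = sorted(heap_arr, key=first_field, reverse=True)

# Sorting ascending and then reversing the list flips the ties as well.
desc_flipped = sorted(heap_arr, key=first_field)[::-1]

print(desc_stable[:2])   # [(3, 'c'), (3, 'd')]
print(desc_flipped[:2])  # [(3, 'd'), (3, 'c')]
print(heapq.nlargest(2, heap_arr, key=first_field) == desc_stable[:2])  # True
```

heapq.nlargest follows the stable reverse=True behavior, not the flipped one.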
Even Longer Answer:
To understand what is really going on under the hood, you could use a debugger or simply copy the code for heapq into a file and use print statements to study how the heap (that is, the variable result) changes as the elements in the iterable are examined and possibly pushed onto the heap. Running this code:
def heapreplace(heap, item):
    """Pop and return the current smallest value, and add the new item.
    This is more efficient than heappop() followed by heappush(), and can be
    more appropriate when using a fixed-size heap. Note that the value
    returned may be larger than item! That constrains reasonable uses of
    this routine unless written as part of a conditional replacement:
        if item > heap[0]:
            item = heapreplace(heap, item)
    """
    returnitem = heap[0]    # raises appropriate IndexError if heap is empty
    heap[0] = item
    _siftup(heap, 0)
    return returnitem

def heapify(x):
    """Transform list into a heap, in-place, in O(len(x)) time."""
    n = len(x)
    # Transform bottom-up. The largest index there's any point to looking at
    # is the largest with a child index in-range, so must have 2*i + 1 < n,
    # or i < (n-1)/2. If n is even = 2*j, this is (2*j-1)/2 = j-1/2 so
    # j-1 is the largest, which is n//2 - 1. If n is odd = 2*j+1, this is
    # (2*j+1-1)/2 = j so j-1 is the largest, and that's again n//2-1.
    for i in reversed(range(n//2)):
        _siftup(x, i)

# 'heap' is a heap at all indices >= startpos, except possibly for pos. pos
# is the index of a leaf with a possibly out-of-order value. Restore the
# heap invariant.
def _siftdown(heap, startpos, pos):
    newitem = heap[pos]
    # Follow the path to the root, moving parents down until finding a place
    # newitem fits.
    while pos > startpos:
        parentpos = (pos - 1) >> 1
        parent = heap[parentpos]
        if newitem < parent:
            heap[pos] = parent
            pos = parentpos
            continue
        break
    heap[pos] = newitem

def _siftup(heap, pos):
    endpos = len(heap)
    startpos = pos
    newitem = heap[pos]
    # Bubble up the smaller child until hitting a leaf.
    childpos = 2*pos + 1    # leftmost child position
    while childpos < endpos:
        # Set childpos to index of smaller child.
        rightpos = childpos + 1
        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        # Move the smaller child up.
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2*pos + 1
    # The leaf at pos is empty now. Put newitem there, and bubble it up
    # to its final resting place (by sifting its parents down).
    heap[pos] = newitem
    _siftdown(heap, startpos, pos)

def nlargest(n, iterable, key=None):
    """Find the n largest elements in a dataset.
    Equivalent to: sorted(iterable, key=key, reverse=True)[:n]
    """
    # Short-cut for n==1 is to use max()
    if n == 1:
        it = iter(iterable)
        sentinel = object()
        if key is None:
            result = max(it, default=sentinel)
        else:
            result = max(it, default=sentinel, key=key)
        return [] if result is sentinel else [result]

    # When n>=size, it's faster to use sorted()
    try:
        size = len(iterable)
    except (TypeError, AttributeError):
        pass
    else:
        if n >= size:
            return sorted(iterable, key=key, reverse=True)[:n]

    # When key is none, use simpler decoration
    if key is None:
        it = iter(iterable)
        result = [(elem, i) for i, elem in zip(range(0, -n, -1), it)]
        print('result: {}'.format(result))    # added for tracing
        if not result:
            return result
        heapify(result)
        top = result[0][0]
        order = -n
        _heapreplace = heapreplace
        for elem in it:
            print('elem: {}'.format(elem))    # added for tracing
            if top < elem:
                _heapreplace(result, (elem, order))
                print('result: {}'.format(result))    # added for tracing
                top, _order = result[0]
                order -= 1
        result.sort(reverse=True)
        return [elem for (elem, order) in result]

    # General case, slowest method
    it = iter(iterable)
    result = [(key(elem), i, elem) for i, elem in zip(range(0, -n, -1), it)]
    print('result: {}'.format(result))    # added for tracing
    if not result:
        return result
    heapify(result)
    top = result[0][0]
    order = -n
    _heapreplace = heapreplace
    for elem in it:
        print('elem: {}'.format(elem))    # added for tracing
        k = key(elem)
        if top < k:
            _heapreplace(result, (k, order, elem))
            print('result: {}'.format(result))    # added for tracing
            top, _order, _elem = result[0]
            order -= 1
    result.sort(reverse=True)
    return [elem for (k, order, elem) in result]

heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]
nlargest(2, heap_arr)
print('-'*10)
nlargest(2, heap_arr, key=lambda a: a[0])
yields
# nlargest(2, heap_arr)
result: [((1, 'a'), 0), ((2, 'b'), -1)]
elem: (2, 'b')
result: [((2, 'b'), -2), ((2, 'b'), -1)]
elem: (3, 'c')
result: [((2, 'b'), -1), ((3, 'c'), -3)]
elem: (3, 'd')
result: [((3, 'c'), -3), ((3, 'd'), -4)] <---- compare this line (1)
----------
# nlargest(2, heap_arr, key=lambda a: a[0])
result: [(1, 0, (1, 'a')), (2, -1, (2, 'b'))]
elem: (2, 'b')
result: [(2, -2, (2, 'b')), (2, -1, (2, 'b'))]
elem: (3, 'c')
result: [(2, -1, (2, 'b')), (3, -3, (3, 'c'))]
elem: (3, 'd')
result: [(3, -4, (3, 'd')), (3, -3, (3, 'c'))] <---- with this line (2)
Remember that in a heap, heap[0] is always the smallest item. And indeed,
In [45]: ((3, 'c'), -3) < ((3, 'd'), -4)
Out[45]: True
In [46]: (3, -4, (3, 'd')) < (3, -3, (3, 'c'))
Out[46]: True
This justifies the result we are seeing on lines (1) and (2).
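Looking only at the undecorated elements, in the first case, (1), (3, 'c') comes before (3, 'd') in the heap, while in the second case, (2), the reverse happens. The final result.sort(reverse=True) in nlargest then flips each of these two-element heaps, which is why the first call returns [(3, 'd'), (3, 'c')] and the second returns [(3, 'c'), (3, 'd')].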
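The last two steps of nlargest (result.sort(reverse=True) followed by stripping the decoration) can be replayed by hand on the heap contents printed at lines (1) and (2). This sketch simply copies those two heaps as literals:

```python
# Heap contents copied from the traced lines (1) and (2).
heap_no_key = [((3, 'c'), -3), ((3, 'd'), -4)]
heap_with_key = [(3, -4, (3, 'd')), (3, -3, (3, 'c'))]

# nlargest finishes by sorting the decorated heap in descending order...
heap_no_key.sort(reverse=True)
heap_with_key.sort(reverse=True)

# ...and then stripping the decoration off again.
out_no_key = [elem for (elem, order) in heap_no_key]
out_with_key = [elem for (k, order, elem) in heap_with_key]

print(out_no_key)    # [(3, 'd'), (3, 'c')]
print(out_with_key)  # [(3, 'c'), (3, 'd')]
```

In the first heap the tuples (3, 'c') and (3, 'd') themselves decide the sort; in the second the equal keys push the decision to the order field, where -3 > -4.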
So the behavior you are seeing arises from the fact that when key is None, the elements from the iterable are placed in the heap as though they were tuples of the form (elem, order), where order decrements by 1 with every _heapreplace. In contrast, when key is not None, the tuples are of the form (k, order, elem), where k is key(elem). This difference in the form of the tuples leads to the difference you see in the result.
In the first case, elem itself decides the ranking. In the second case, since the k values are equal, order breaks the tie; the purpose of order is to break ties in a stable way. So ultimately we reach the same conclusion as when we examined sorted(heap_arr, key=lambda a: a[0], reverse=True): when the keys are equal, (3, 'c') and (3, 'd') come out in the same order as they appear in heap_arr.
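To see why the order field is enough to make the result stable, here is a stripped-down sketch of the same decorate-and-heap idea, written against the public heapq.heappush and heapq.heapreplace functions. The name stable_nlargest is hypothetical, and this is a simplification for illustration (no n == 1 or n >= len shortcuts), not the library's code.

```python
import heapq
from itertools import count


def stable_nlargest(n, iterable, key):
    """Hypothetical simplified sketch of nlargest's decorated-heap strategy.

    Heap entries are (key(elem), order, elem). order starts at 0 and only
    decreases, so among equal keys the earlier element has the larger order
    and ranks higher. Because orders are unique, elem is never compared.
    """
    if n <= 0:
        return []
    heap = []              # min-heap holding the n best entries seen so far
    orders = count(0, -1)  # 0, -1, -2, ...
    for elem in iterable:
        k = key(elem)
        if len(heap) < n:
            heapq.heappush(heap, (k, next(orders), elem))
        elif heap[0][0] < k:
            # Strictly larger key only: an equal key never evicts an earlier item.
            heapq.heapreplace(heap, (k, next(orders), elem))
    heap.sort(reverse=True)
    return [elem for (k, order, elem) in heap]


heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]
print(stable_nlargest(2, heap_arr, key=lambda a: a[0]))  # [(3, 'c'), (3, 'd')]
```

heap[0] is always the lowest-ranked entry (smallest key and, among equal keys, the latest arrival), so it is the right one to evict when a strictly larger key shows up.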
If you want the ties in a[0] to be broken by a itself, then use
In [53]: heapq.nlargest(2, heap_arr, key=lambda a: (a[0], a))
Out[53]: [(3, 'd'), (3, 'c')]
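The same approach extends to other tie-breaking rules. The sketch below uses a small made-up list (data is illustrative, chosen so that three items tie on a[0]) to compare input order, ordering by the tuple itself, and a later-items-first rule built by pairing each item with its index via enumerate():

```python
import heapq

# Illustrative data: three items tie on a[0] == 3.
data = [(3, 'b'), (1, 'z'), (3, 'c'), (3, 'a')]

# Ties on a[0] keep input order (stable, like sorted(..., reverse=True)).
input_order = heapq.nlargest(3, data, key=lambda a: a[0])

# Ties on a[0] broken by comparing the whole tuple, largest first.
by_item = heapq.nlargest(3, data, key=lambda a: (a[0], a))

# Ties on a[0] broken by position, later items first: rank (index, item)
# pairs on (a[0], index), then drop the index again.
latest_first = [
    item
    for _, item in heapq.nlargest(3, enumerate(data), key=lambda p: (p[1][0], p[0]))
]

print(input_order)   # [(3, 'b'), (3, 'c'), (3, 'a')]
print(by_item)       # [(3, 'c'), (3, 'b'), (3, 'a')]
print(latest_first)  # [(3, 'a'), (3, 'c'), (3, 'b')]
```

Adding the index to the key also means the items themselves are never compared, which can matter when they are not orderable.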
Upvotes: 4