Manu

Order of output changes for heapq.nlargest with a key function (Python)

Could someone please explain why the order of the output changes when nlargest is called with a key function that uses only the first element of each tuple?

import heapq
heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]

heapq.nlargest(2, heap_arr)
# Output: [(3, 'd'), (3, 'c')], as expected
# Same result as heapq.nlargest(2, heap_arr, key=lambda a: (a[0], a[1]))

heapq.nlargest(2, heap_arr, key=lambda a: a[0])
# Output: [(3, 'c'), (3, 'd')]... Why?

Why does (3, 'c') appear before (3, 'd') in the second example? I'm asking because the order of the tuples in the output list is important.

Answers (1)

unutbu

Short Answer:

heapq.nlargest(2, heap_arr) returns [(3, 'd'), (3, 'c')] because

In [6]: (3, 'd') > (3, 'c')
Out[6]: True

heapq.nlargest(2, heap_arr, key=lambda a: a[0]) returns [(3, 'c'), (3, 'd')] because heapq.nlargest, like sorted, is stable: when two items have equal keys (here both keys are 3), they are returned in the order in which they appear in heap_arr:

In [8]: heapq.nlargest(2, [(3, 'c'), (3, 'd')], key=lambda a: a[0])
Out[8]: [(3, 'c'), (3, 'd')]

In [9]: heapq.nlargest(2, [(3, 'd'), (3, 'c')], key=lambda a: a[0])
Out[9]: [(3, 'd'), (3, 'c')]
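The same rule holds with more than two tied items. A small sketch, where the list items and its contents are made up for illustration: the three items sharing the largest key come back in input order when a key function is used, but are ordered by their letters when the whole tuple is compared.

```python
import heapq

# Hypothetical data: three items tie on the key value 5.
items = [(5, 'x'), (1, 'p'), (5, 'y'), (2, 'q'), (5, 'z')]

# Key looks only at a[0], so the tied items keep their input order.
by_first = heapq.nlargest(3, items, key=lambda a: a[0])
# No key: whole tuples are compared, so the letters break the ties.
by_whole = heapq.nlargest(3, items)

print(by_first)  # [(5, 'x'), (5, 'y'), (5, 'z')]
print(by_whole)  # [(5, 'z'), (5, 'y'), (5, 'x')]
```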

Longer Answer:

Per the docs, heapq.nlargest(n, iterable, key) is equivalent to

sorted(iterable, key=key, reverse=True)[:n]

(though heapq.nlargest computes its result in a different way). Nevertheless, we can use this equivalency to check that heapq.nlargest is behaving as we should expect:

import heapq
heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]

assert heapq.nlargest(2, heap_arr) == sorted(heap_arr, reverse=True)[:2]

assert heapq.nlargest(2, heap_arr, key=lambda a: a[0]) == sorted(heap_arr, key=lambda a: a[0], reverse=True)[:2]
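The two assertions above check a single input. A minimal sketch, using only the standard library, that checks the documented equivalence on many random inputs with deliberately frequent ties; matches_sorted is a hypothetical helper name, not part of heapq.

```python
import heapq
import random


def matches_sorted(data, n, key=None):
    """Compare heapq.nlargest with its documented sorted(...)[:n] equivalent."""
    return heapq.nlargest(n, data, key=key) == sorted(data, key=key, reverse=True)[:n]


random.seed(0)  # fixed seed so the run is repeatable
for _ in range(1000):
    # Small value ranges force many ties, which is where ordering matters.
    data = [(random.randint(0, 3), random.choice('abc'))
            for _ in range(random.randint(0, 12))]
    n = random.randint(0, 14)  # also covers n == 0, n == 1 and n > len(data)
    assert matches_sorted(data, n)
    assert matches_sorted(data, n, key=lambda a: a[0])

print("heapq.nlargest agreed with sorted(..., reverse=True)[:n] on every trial")
```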

So if you accept this equivalency, then you merely have to confirm the validity of

In [47]: sorted(heap_arr, reverse=True)
Out[47]: [(3, 'd'), (3, 'c'), (2, 'b'), (2, 'b'), (1, 'a')]

In [48]: sorted(heap_arr, key=lambda a: a[0], reverse=True)
Out[48]: [(3, 'c'), (3, 'd'), (2, 'b'), (2, 'b'), (1, 'a')]

When using key=lambda a: a[0], (3, 'c') and (3, 'd') are sorted according to the same key value, 3. Because Python's sort is stable, two items with equal keys (e.g. (3, 'c') and (3, 'd')) appear in the result in the same order as they appear in heap_arr, even with reverse=True.
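Note that reverse=True is not the same as sorting ascending and then reversing the list: per the Python docs, reverse=True still keeps records with equal keys in their original order. A quick check with the question's heap_arr:

```python
from operator import itemgetter

heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]
key = itemgetter(0)  # same key as lambda a: a[0]

# reverse=True keeps equal-key items in their original relative order...
stable_desc = sorted(heap_arr, key=key, reverse=True)
# ...whereas sorting ascending and then reversing the list flips tied items.
flipped = sorted(heap_arr, key=key)[::-1]

print(stable_desc[:2])  # [(3, 'c'), (3, 'd')]
print(flipped[:2])      # [(3, 'd'), (3, 'c')]
```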


Even longer answer:

To understand what is really going on under the hood, you could use a debugger, or simply copy the code for heapq into a file and add print statements to study how the heap (that is, the variable result) changes as the elements of the iterable are examined and possibly pushed onto the heap. Running this code:

def heapreplace(heap, item):
    """Pop and return the current smallest value, and add the new item.

    This is more efficient than heappop() followed by heappush(), and can be
    more appropriate when using a fixed-size heap.  Note that the value
    returned may be larger than item!  That constrains reasonable uses of
    this routine unless written as part of a conditional replacement:

        if item > heap[0]:
            item = heapreplace(heap, item)
    """
    returnitem = heap[0]    # raises appropriate IndexError if heap is empty
    heap[0] = item
    _siftup(heap, 0)
    return returnitem

def heapify(x):
    """Transform list into a heap, in-place, in O(len(x)) time."""
    n = len(x)
    # Transform bottom-up.  The largest index there's any point to looking at
    # is the largest with a child index in-range, so must have 2*i + 1 < n,
    # or i < (n-1)/2.  If n is even = 2*j, this is (2*j-1)/2 = j-1/2 so
    # j-1 is the largest, which is n//2 - 1.  If n is odd = 2*j+1, this is
    # (2*j+1-1)/2 = j so j-1 is the largest, and that's again n//2-1.
    for i in reversed(range(n//2)):
        _siftup(x, i)

# 'heap' is a heap at all indices >= startpos, except possibly for pos.  pos
# is the index of a leaf with a possibly out-of-order value.  Restore the
# heap invariant.
def _siftdown(heap, startpos, pos):
    newitem = heap[pos]
    # Follow the path to the root, moving parents down until finding a place
    # newitem fits.
    while pos > startpos:
        parentpos = (pos - 1) >> 1
        parent = heap[parentpos]
        if newitem < parent:
            heap[pos] = parent
            pos = parentpos
            continue
        break
    heap[pos] = newitem


def _siftup(heap, pos):
    endpos = len(heap)
    startpos = pos
    newitem = heap[pos]
    # Bubble up the smaller child until hitting a leaf.
    childpos = 2*pos + 1    # leftmost child position
    while childpos < endpos:
        # Set childpos to index of smaller child.
        rightpos = childpos + 1
        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        # Move the smaller child up.
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2*pos + 1
    # The leaf at pos is empty now.  Put newitem there, and bubble it up
    # to its final resting place (by sifting its parents down).
    heap[pos] = newitem
    _siftdown(heap, startpos, pos)


def nlargest(n, iterable, key=None):
    """Find the n largest elements in a dataset.

    Equivalent to:  sorted(iterable, key=key, reverse=True)[:n]
    """

    # Short-cut for n==1 is to use max()
    if n == 1:
        it = iter(iterable)
        sentinel = object()
        if key is None:
            result = max(it, default=sentinel)
        else:
            result = max(it, default=sentinel, key=key)
        return [] if result is sentinel else [result]

    # When n>=size, it's faster to use sorted()
    try:
        size = len(iterable)
    except (TypeError, AttributeError):
        pass
    else:
        if n >= size:
            return sorted(iterable, key=key, reverse=True)[:n]

    # When key is none, use simpler decoration
    if key is None:
        it = iter(iterable)
        result = [(elem, i) for i, elem in zip(range(0, -n, -1), it)]
        print('result: {}'.format(result))
        if not result:
            return result
        heapify(result)
        top = result[0][0]
        order = -n
        _heapreplace = heapreplace
        for elem in it:
            print('elem: {}'.format(elem))
            if top < elem:
                _heapreplace(result, (elem, order))
                print('result: {}'.format(result))
                top, _order = result[0]
                order -= 1
        result.sort(reverse=True)
        return [elem for (elem, order) in result]

    # General case, slowest method
    it = iter(iterable)
    result = [(key(elem), i, elem) for i, elem in zip(range(0, -n, -1), it)]
    print('result: {}'.format(result))
    if not result:
        return result
    heapify(result)
    top = result[0][0]
    order = -n
    _heapreplace = heapreplace
    for elem in it:
        print('elem: {}'.format(elem))
        k = key(elem)
        if top < k:
            _heapreplace(result, (k, order, elem))
            print('result: {}'.format(result))
            top, _order, _elem = result[0]
            order -= 1
    result.sort(reverse=True)
    return [elem for (k, order, elem) in result]


heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]

nlargest(2, heap_arr)
print('-'*10)
nlargest(2, heap_arr, key=lambda a: a[0]) 

yields

# nlargest(2, heap_arr)
result: [((1, 'a'), 0), ((2, 'b'), -1)]
elem: (2, 'b')
result: [((2, 'b'), -2), ((2, 'b'), -1)]
elem: (3, 'c')
result: [((2, 'b'), -1), ((3, 'c'), -3)]
elem: (3, 'd')
result: [((3, 'c'), -3), ((3, 'd'), -4)]      <---- compare this line (1)
----------

# nlargest(2, heap_arr, key=lambda a: a[0])
result: [(1, 0, (1, 'a')), (2, -1, (2, 'b'))]
elem: (2, 'b')
result: [(2, -2, (2, 'b')), (2, -1, (2, 'b'))]
elem: (3, 'c')
result: [(2, -1, (2, 'b')), (3, -3, (3, 'c'))]
elem: (3, 'd')
result: [(3, -4, (3, 'd')), (3, -3, (3, 'c'))] <---- with this line (2)

Remember that in a heap, heap[0] is always the smallest item. And indeed,

In [45]: ((3, 'c'), -3) < ((3, 'd'), -4)
Out[45]: True

In [46]: (3, -4, (3, 'd')) < (3, -3, (3, 'c'))
Out[46]: True

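This explains the heaps on lines (1) and (2). Stripping the decoration, in case (1) (3, 'c') sits before (3, 'd'), while in case (2) the reverse holds. The heap layout is not yet the final answer, though: nlargest finishes with result.sort(reverse=True), which reverses each of these two-element heaps, so case (1) returns [(3, 'd'), (3, 'c')] and case (2) returns [(3, 'c'), (3, 'd')].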
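Applying nlargest's last two steps (the descending sort and the undecoration) by hand to the final heaps marked (1) and (2) reproduces both outputs from the question. This check uses only the heap contents printed in the trace:

```python
# Final heap contents copied from lines (1) and (2) of the trace.
heap_no_key = [((3, 'c'), -3), ((3, 'd'), -4)]
heap_with_key = [(3, -4, (3, 'd')), (3, -3, (3, 'c'))]

# nlargest's last step: sort the decorated tuples in descending order...
heap_no_key.sort(reverse=True)
heap_with_key.sort(reverse=True)

# ...then strip the decoration.
out_no_key = [elem for (elem, order) in heap_no_key]
out_with_key = [elem for (k, order, elem) in heap_with_key]

print(out_no_key)    # [(3, 'd'), (3, 'c')]
print(out_with_key)  # [(3, 'c'), (3, 'd')]
```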

So the behavior you are seeing arises because, when key is None, the elements of the iterable are placed in the heap as tuples of the form (elem, order), where order decreases by 1 with every _heapreplace. In contrast, when key is not None, the tuples have the form (k, order, elem), where k is key(elem). This difference in the shape of the decorated tuples leads to the difference you see in the result.
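A minimal sketch of this decorate/heap/undecorate scheme, assuming only the standard library. nlargest_sketch is a hypothetical name, not heapq's code: it drops the real function's n == 1 and n >= len(iterable) shortcuts, and it numbers every element with itertools.count rather than only the replaced ones. The tie-breaking is the same, because later items still get smaller order values; for the same reason, comparing a whole entry against heap[0] never replaces an entry on a tie, just like the real code's top < k test.

```python
import heapq
from itertools import count


def nlargest_sketch(n, iterable, key=None):
    """Hypothetical, simplified version of heapq.nlargest's decoration scheme."""
    if n <= 0:
        return []
    heap = []             # min-heap of the n largest decorated entries seen so far
    order = count(0, -1)  # 0, -1, -2, ...: earlier items get larger tie-breakers
    for elem in iterable:
        if key is None:
            entry = (elem, next(order))             # like (elem, order)
        else:
            entry = (key(elem), next(order), elem)  # like (k, order, elem)
        if len(heap) < n:
            heapq.heappush(heap, entry)
        elif heap[0] < entry:
            heapq.heapreplace(heap, entry)          # evict the current smallest entry
    heap.sort(reverse=True)                         # largest decorated entry first
    return [entry[0] if key is None else entry[2] for entry in heap]


heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]
print(nlargest_sketch(2, heap_arr))                      # [(3, 'd'), (3, 'c')]
print(nlargest_sketch(2, heap_arr, key=lambda a: a[0]))  # [(3, 'c'), (3, 'd')]
```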

In the first case, elem decides which tuple is larger. In the second case, since the k values are equal, order decides. The purpose of order is to break ties in a stable way: items seen earlier get larger order values, so after the final descending sort they come first. So ultimately we reach the same conclusion as when we examined sorted(heap_arr, key=lambda a: a[0], reverse=True): when the keys are equal, (3, 'c') and (3, 'd') keep the order they have in heap_arr.
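A side effect of putting order before elem: every order value is unique, so a tie on k is always settled by order and the elements themselves are never compared. A short sketch with hypothetical dict records (dicts do not support <) shows nlargest handling tied keys, while a naive (k, elem) decoration raises TypeError on the same tie:

```python
import heapq

# Hypothetical records; comparing two dicts with < raises TypeError.
records = [
    {'score': 3, 'name': 'c'},
    {'score': 1, 'name': 'a'},
    {'score': 3, 'name': 'd'},
]

# Tied scores are resolved by nlargest's internal order counter.
top_two = heapq.nlargest(2, records, key=lambda r: r['score'])
print([r['name'] for r in top_two])  # ['c', 'd'], input order for the tie

# Without an order field, the tie on 3 falls through to comparing the dicts.
naive_failed = False
try:
    sorted([(3, records[0]), (3, records[2])])
except TypeError as exc:
    naive_failed = True
    print('TypeError:', exc)
```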

If you want ties in a[0] to be broken by comparing a itself, use

In [53]: heapq.nlargest(2, heap_arr, key=lambda a: (a[0], a))
Out[53]: [(3, 'd'), (3, 'c')]
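For 2-tuples like these, operator.itemgetter(0, 1) builds the same (a[0], a[1]) key mentioned in the question. The sketch below, which assumes the question's heap_arr, compares it with the lambda above and with no key at all. All three should agree:

```python
import heapq
from operator import itemgetter

heap_arr = [(1, 'a'), (2, 'b'), (2, 'b'), (3, 'c'), (3, 'd')]

# Three ways to make ties on a[0] fall back to comparing the rest of the tuple.
no_key = heapq.nlargest(2, heap_arr)
whole_tuple = heapq.nlargest(2, heap_arr, key=lambda a: (a[0], a))
first_then_second = heapq.nlargest(2, heap_arr, key=itemgetter(0, 1))

# Each result is [(3, 'd'), (3, 'c')].
print(no_key, whole_tuple, first_then_second)
```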
