Arvind Raghu

Reputation: 458

How do I make this Merge Sort function a generator (Python)?

I understand how to implement the merge sort algorithm in Python 3. Here is my code for the function:

def x(arr):
    for i in mergeSort(arr):
        yield from i

def mergeSort(arr):
    if len(arr) > 1:
        middle = len(arr) // 2
        left = arr[:middle]
        right = arr[middle:]

        mergeSort(left)
        mergeSort(right)

        a = 0
        b = 0
        c = 0

        while a < len(left) and b < len(right):
            if left[a] < right[b]:
                arr[c] = left[a]
                a += 1
            else:
                arr[c] = right[b]
                b += 1
            c += 1

        while a < len(left):
            arr[c] = left[a]
            a += 1
            c += 1

        while b < len(right):
            arr[c] = right[b]
            b += 1
            c += 1

for i in mergeSort([6,3,8,7,4,1,2,9,5,0]):
    print(i)

The gist of it is that the function sorts the list in place, so the list is sorted at the end. However, I am trying to build a sorting visualiser, and for that the function needs to yield the array whenever a change is made, so you can see the bars switch. In other words, it needs to be a generator, but none of my attempts to make it one have worked. How could I modify my code to make this function a generator?

Thanks.

Upvotes: 2

Views: 1027

Answers (2)

trincot

Reputation: 350771

First you need to make sure that a call deep in the recursion can actually report on the state of the whole list. With your current setup that is not possible, since each recursive call only sees its own slice (a copy) of the array.

To fix that, don't pass slices to the recursive call; pass start/end indices instead, so that every level of the recursion works on the same arr.

Then you can yield arr after each merge. The code that makes a recursive call should use yield from, so that whatever the nested call yields is passed on to the caller.
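Why yield from matters can be seen in a toy example. Calling a generator function only creates a generator object; none of its body runs until something iterates over it. So inside a generator, a plain recursive call such as mergeSort(left) neither sorts its part nor reports anything. The two functions below are hypothetical and only illustrate this:

```python
def up_to_lost(n):
    """Recursive generator WITHOUT yield from: the recursive call is never iterated."""
    if n > 0:
        up_to_lost(n - 1)  # only creates a generator object, which is then discarded
        yield n


def up_to(n):
    """Recursive generator that delegates to the recursive call with yield from."""
    if n > 0:
        yield from up_to(n - 1)  # re-yields everything the nested call yields
        yield n


print(list(up_to_lost(3)))  # [3]
print(list(up_to(3)))       # [1, 2, 3]
```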

I adapted your code only to apply the above idea:

def mergeSort(arr):
    # arr is a unique list that all levels in the recursion tree can access:

    def mergeSortRec(start, end):  # separate function that can take start/end indices
        if end - start > 1:
            middle = (start + end) // 2

            yield from mergeSortRec(start, middle)  # don't provide slice, but index range
            yield from mergeSortRec(middle, end)
            left = arr[start:middle]
            right = arr[middle:end]

            a = 0
            b = 0
            c = start

            while a < len(left) and b < len(right):
                if left[a] < right[b]:
                    arr[c] = left[a]
                    a += 1
                else:
                    arr[c] = right[b]
                    b += 1
                c += 1

            while a < len(left):
                arr[c] = left[a]
                a += 1
                c += 1

            while b < len(right):
                arr[c] = right[b]
                b += 1
                c += 1
            
            yield arr

    yield from mergeSortRec(0, len(arr))  # call inner function with start/end arguments

for i in mergeSort([6,3,8,7,4,1,2,9,5,0]):
    print(i)

For the example list, the output is the following:

[3, 6, 8, 7, 4, 1, 2, 9, 5, 0]
[3, 6, 8, 4, 7, 1, 2, 9, 5, 0]
[3, 6, 4, 7, 8, 1, 2, 9, 5, 0]
[3, 4, 6, 7, 8, 1, 2, 9, 5, 0]
[3, 4, 6, 7, 8, 1, 2, 9, 5, 0]
[3, 4, 6, 7, 8, 1, 2, 9, 0, 5]
[3, 4, 6, 7, 8, 1, 2, 0, 5, 9]
[3, 4, 6, 7, 8, 0, 1, 2, 5, 9]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
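Note that every step yields the same list object, which is why printing inside the loop shows each intermediate state. If you collect the steps instead (for example with list(...) to replay them later in the visualiser), every entry refers to that one list, which ends up fully sorted. Copying each snapshot as it arrives avoids this. The generator below is a hypothetical stand-in that mutates and yields a list the same way mergeSort does:

```python
def zero_out(arr):
    """Hypothetical stand-in: mutate arr in place and yield the same list after each write."""
    for i in range(len(arr)):
        arr[i] = 0
        yield arr


# Collecting the yielded objects directly: all entries are the same list object
aliased = list(zero_out([1, 2, 3]))
print(aliased)    # [[0, 0, 0], [0, 0, 0], [0, 0, 0]]

# Copying each snapshot as it arrives preserves the intermediate states
snapshots = [step.copy() for step in zero_out([1, 2, 3])]
print(snapshots)  # [[0, 2, 3], [0, 0, 3], [0, 0, 0]]
```

With mergeSort the same pattern applies: copy each yielded list (step.copy() or list(step)) before storing it.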

You could also yield the start/end indices, so that the consumer of the iterator knows exactly which part of the list the algorithm just merged. To do so, change:

        yield arr

to:

        yield arr, start, end

With that change, the output becomes:

([3, 6, 8, 7, 4, 1, 2, 9, 5, 0], 0, 2)
([3, 6, 8, 4, 7, 1, 2, 9, 5, 0], 3, 5)
([3, 6, 4, 7, 8, 1, 2, 9, 5, 0], 2, 5)
([3, 4, 6, 7, 8, 1, 2, 9, 5, 0], 0, 5)
([3, 4, 6, 7, 8, 1, 2, 9, 5, 0], 5, 7)
([3, 4, 6, 7, 8, 1, 2, 9, 0, 5], 8, 10)
([3, 4, 6, 7, 8, 1, 2, 0, 5, 9], 7, 10)
([3, 4, 6, 7, 8, 0, 1, 2, 5, 9], 5, 10)
([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 0, 10)
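As a rough idea of how a visualiser might consume these tuples, here is a text-mode sketch. The render function is hypothetical; a graphical version would draw rectangles instead of characters. Each value becomes a bar, and the rows in the half-open range from start to end (the part that was just merged) are marked:

```python
def render(arr, start, end, bar="#", mark="*"):
    """Return arr as text bars, marking the rows with start <= index < end."""
    lines = []
    for i, value in enumerate(arr):
        prefix = mark if start <= i < end else " "  # highlight the merged range
        lines.append(f"{prefix} {bar * value}")
    return "\n".join(lines)


# First step from the output above: the merge of indices 0 and 1
print(render([3, 6, 8, 7, 4, 1, 2, 9, 5, 0], 0, 2))
```

With the modified generator, the driving loop would iterate over "snapshot, start, end in mergeSort(data)" and call render(snapshot, start, end) on each step, plus whatever delay or redraw your visualiser needs.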

Upvotes: 4

PApostol

Reputation: 2292

Maybe something like this:

def x(arr):
    for i in mergeSort(arr):
        yield i

def mergeSort(arr):
    if len(arr) > 1:
        middle = len(arr) // 2
        left = arr[:middle]
        right = arr[middle:]

        mergeSort(left)
        mergeSort(right)

        a = 0
        b = 0
        c = 0

        while a < len(left) and b < len(right):
            if left[a] < right[b]:
                arr[c] = left[a]
                a += 1
            else:
                arr[c] = right[b]
                b += 1
            c += 1

        while a < len(left):
            arr[c] = left[a]
            a += 1
            c += 1

        while b < len(right):
            arr[c] = right[b]
            b += 1
            c += 1
        return arr

# Entry point
generator = x([6,3,8,7,4,1,2,9,5,0])
print(next(generator)) # prints 0
print(next(generator)) # prints 1

# print the remaining elements
for i in generator:
    print(i)

Output:

0
1
2
3
4
5
6
7
8
9
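Note that x above only starts yielding after mergeSort has finished, so it produces the sorted elements one by one rather than the intermediate states a visualiser needs. If you would rather avoid recursion altogether, a bottom-up (iterative) merge sort is one alternative: it merges runs of width 1, 2, 4, ... directly in the one list, so yielding the whole list after each merge is straightforward. This is a swapped-in variant, not the code from the question; the function name is hypothetical, and it compares with <= so that equal elements keep their original order:

```python
def merge_sort_bottom_up(arr):
    """Sort arr in place, yielding (arr, start, end) after each merge of arr[start:end]."""
    n = len(arr)
    width = 1  # length of the already-sorted runs being merged
    while width < n:
        # Each pass merges neighbouring runs arr[start:middle] and arr[middle:end]
        for start in range(0, n - width, 2 * width):
            middle = start + width
            end = min(start + 2 * width, n)
            left = arr[start:middle]
            right = arr[middle:end]
            a = b = 0
            c = start
            while a < len(left) and b < len(right):
                if left[a] <= right[b]:
                    arr[c] = left[a]
                    a += 1
                else:
                    arr[c] = right[b]
                    b += 1
                c += 1
            # At most one side still has elements; copy them into the remaining slots
            arr[c:end] = left[a:] + right[b:]
            yield arr, start, end
        width *= 2


data = [6, 3, 8, 7, 4, 1, 2, 9, 5, 0]
for snapshot, start, end in merge_sort_bottom_up(data):
    print(snapshot, start, end)
```

The order of the steps differs from the recursive version (all pairs are merged first, then groups of four, and so on), which may or may not suit the animation you have in mind.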

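Note that a shorter recursive sorting function you could use is the following. Strictly speaking it is a quicksort rather than a merge sort: it partitions the list around a pivot (the middle element) instead of merging sorted halves, and it builds new lists instead of sorting in place.

def quick_sort(mylist):
    if len(mylist) < 2:
        return mylist

    less = []
    equal = []
    greater = []

    pivot = mylist[len(mylist) // 2]  # middle element as pivot

    # three-way partition around the pivot
    for x in mylist:
        if x < pivot:
            less.append(x)
        elif x == pivot:
            equal.append(x)
        else:
            greater.append(x)

    return quick_sort(less) + equal + quick_sort(greater)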

Upvotes: 1
