I have written a recursive randomized quicksort function:
import random

def randomized_quick_sort(a, l, r):
    if l >= r:
        return
    k = random.randint(l, r)
    a[l], a[k] = a[k], a[l]
    m1, m2 = partition3(a, l, r)
    randomized_quick_sort(a, l, m1 - 1)
    randomized_quick_sort(a, m2 + 1, r)
The partition function, given below, splits a[l..r] into three parts: elements less than the pivot, equal to the pivot, and greater than the pivot, where the pivot is the first element of the range.
def partition3(a, l, r):
    x = a[l]
    less, equal, greater = [], [], []
    for val in a[l:r+1]:
        if val < x:
            less.append(val)
        if val == x:
            equal.append(val)
        if val > x:
            greater.append(val)
    a[l:r+1] = less + equal + greater
    m1 = len(less)
    m2 = m1 + len(equal) - 1
    return m1, m2
If I run this quicksort function several times on a simple input such as

a = [2, 2, 3, 3]
randomized_quick_sort(a, 0, len(a) - 1)

then after only a few trials I get a "maximum recursion depth exceeded" error. Please help!
One of the properties of quicksort is that it sorts in place: it rearranges elements within the input list instead of building new lists, so apart from the recursion stack it needs no extra space. You can keep track of indices into the input list and swap elements to do the sort in place. Here's an example solution using a random pivot and a Lomuto-style partition:
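import random

def partition(arr, start, end):
    # Lomuto partition: the pivot is arr[end]; every element <= pivot
    # is swapped to the front of arr[start:end].
    pivot = arr[end]
    ix = start
    for i in range(start, end):
        if arr[i] <= pivot:
            arr[i], arr[ix] = arr[ix], arr[i]
            ix += 1
    # Place the pivot between the two parts and return its final index.
    arr[ix], arr[end] = arr[end], arr[ix]
    return ix

def quick_sort(arr, start, end):
    if start > end:
        return
    # Move a randomly chosen pivot to the end of the range.
    rand_num = random.randint(start, end)
    arr[rand_num], arr[end] = arr[end], arr[rand_num]
    ix = partition(arr, start, end)
    quick_sort(arr, start, ix - 1)
    quick_sort(arr, ix + 1, end)

arr = [2, 4, 7, 8, 9, 1, 3, 5, 6, 12, 32]
quick_sort(arr, 0, len(arr) - 1)
print(arr)

Output: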
[1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 32]
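The partition above is a two-way scheme. If you'd rather keep the question's three-way split (less / equal / greater) but do it in place, one common option is Dijkstra's three-way ("Dutch national flag") partition. The following is only a sketch; the names partition3_inplace and quick_sort3 are hypothetical. It returns absolute indices of the pivot-equal block, so each recursive call works on a strictly smaller range.

```python
import random

def partition3_inplace(a, lo, hi):
    """Three-way partition of a[lo..hi] (inclusive) around the pivot a[lo], in place.

    Afterwards a[lo:lt] < pivot, a[lt:gt+1] == pivot, a[gt+1:hi+1] > pivot.
    Returns the absolute indices (lt, gt) of the block equal to the pivot.
    """
    pivot = a[lo]
    lt, i, gt = lo, lo, hi
    while i <= gt:
        if a[i] < pivot:
            # Grow the "less" region; a[lt] was equal to the pivot (or lt == i).
            a[lt], a[i] = a[i], a[lt]
            lt += 1
            i += 1
        elif a[i] > pivot:
            # Grow the "greater" region; don't advance i, because the value
            # swapped in from position gt has not been examined yet.
            a[i], a[gt] = a[gt], a[i]
            gt -= 1
        else:
            i += 1
    return lt, gt

def quick_sort3(a, lo, hi):
    if lo >= hi:
        return
    k = random.randint(lo, hi)
    a[lo], a[k] = a[k], a[lo]  # move a random pivot to the front
    lt, gt = partition3_inplace(a, lo, hi)
    quick_sort3(a, lo, lt - 1)
    quick_sort3(a, gt + 1, hi)

a = [2, 2, 3, 3]
quick_sort3(a, 0, len(a) - 1)
print(a)  # [2, 2, 3, 3]
```

Grouping all copies of the pivot in the middle also means a list full of duplicates is finished in one pass instead of being split one element at a time.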
This is actually pretty close, but I recommend testing partition3(a, l, r) by itself. You'll see that the indices it returns are wrong whenever l is not 0.
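A minimal way to do that test is to call the question's partition3, unchanged, once on a whole list and once on a sub-range that doesn't start at index 0:

```python
def partition3(a, l, r):
    # The question's partition3, unchanged.
    x = a[l]
    less, equal, greater = [], [], []
    for val in a[l:r+1]:
        if val < x:
            less.append(val)
        if val == x:
            equal.append(val)
        if val > x:
            greater.append(val)
    a[l:r+1] = less + equal + greater
    m1 = len(less)
    m2 = m1 + len(equal) - 1
    return m1, m2

# Whole list (l = 0): the result looks right.
a = [3, 1, 3, 2, 5]
print(partition3(a, 0, 4), a)  # (2, 3) [1, 2, 3, 3, 5]

# Sub-range l = 2, r = 3: the pivot-equal block sits at indices 2..3,
# but the function reports (0, 1).
b = [2, 2, 3, 3]
print(partition3(b, 2, 3), b)  # (0, 1) [2, 2, 3, 3]
```

With (0, 1) returned for l = 2, r = 3, the question's second recursive call becomes randomized_quick_sort(a, 2, 3) again, the same range, so the recursion never ends. On [2, 2, 3, 3] that sub-range is reached only when the first random pivot is a 2, which is why the error shows up on some runs and not others.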
However, with a small change, we can get it to work:
m1 = len(less)
should be:
m1 = len(less) + l # l for left, not 1 for one
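You don't want m1 to be just the number of items in less, because m1 has to be an index into the whole list. If you were partitioning indices 9 through 11 and one item was less than the pivot, you'd return 1 when you meant to return 10. Since m2 is computed from m1, the same change fixes it too.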
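Putting that fix into the question's code gives the sketch below. The only functional change is the m1 line; the if chain is also rewritten as if/elif/else, which doesn't change behaviour. The loop at the end reruns the question's failing input many times so that every random pivot choice gets exercised.

```python
import random

def partition3(a, l, r):
    x = a[l]
    less, equal, greater = [], [], []
    for val in a[l:r+1]:
        if val < x:
            less.append(val)
        elif val == x:
            equal.append(val)
        else:
            greater.append(val)
    a[l:r+1] = less + equal + greater
    m1 = len(less) + l        # absolute index of the first pivot-equal element
    m2 = m1 + len(equal) - 1  # absolute index of the last pivot-equal element
    return m1, m2

def randomized_quick_sort(a, l, r):
    if l >= r:
        return
    k = random.randint(l, r)
    a[l], a[k] = a[k], a[l]
    m1, m2 = partition3(a, l, r)
    randomized_quick_sort(a, l, m1 - 1)
    randomized_quick_sort(a, m2 + 1, r)

# Rerun the question's input; with the fix this finishes every time.
for _ in range(1000):
    a = [2, 2, 3, 3]
    randomized_quick_sort(a, 0, len(a) - 1)
    assert a == [2, 2, 3, 3]
```

Because the equal block always contains at least the pivot itself, both recursive calls now cover strictly fewer elements than the current range.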
Also, in general, try to avoid single-letter variable names (especially l, which is easy to confuse with 1). They make code hard to read, and hard for people unfamiliar with it to follow.