Sundram Sharan

Reputation: 23

How does the "low = pi+1" statement perform tail call elimination in a quicksort function that uses the Lomuto partition scheme?

/* QuickSort after tail call elimination */

#include<stdio.h>

/* A utility function to swap two elements */

void swap(int* a, int* b)
{
    int t = *a;
    *a = *b;
    *b = t;
}

/* This function takes the last element as pivot, places the pivot element at its correct position in the sorted array, and places all elements smaller than or equal to the pivot to its left and all greater elements to its right. It uses the Lomuto partition scheme. */

int partition (int arr[], int low, int high)
{
    int pivot = arr[high];    // pivot
    int i = (low - 1);  // Index of smaller element

    for (int j = low; j <= high - 1; j++)
    {
        // If the current element is smaller than or equal to pivot
        if (arr[j] <= pivot)
        {
            i++;    // increment index of smaller element
            swap(&arr[i], &arr[j]);
        }
    }
    swap(&arr[i + 1], &arr[high]);
    return (i + 1);
}

/* The main function that implements QuickSort arr[] --> Array to be sorted, low --> Starting index, high --> Ending index */

void quickSort(int arr[], int low, int high)
{
    while (low < high)
    {
        /* pi is partitioning index, arr[pi] is now at right place */
        int pi = partition(arr, low, high);

        // Recursively sort the elements before the partition
        quickSort(arr, low, pi - 1);

        // Instead of a second recursive call for the elements after the
        // partition, let the loop continue on pi+1..high
        low = pi+1;
    }
}

/* Function to print an array */

void printArray(int arr[], int size)
{
    for (int i=0; i < size; i++)
        printf("%d ", arr[i]);
    printf("\n");
}

/* Driver program to test above functions */

int main()
{
    int arr[] = {10, 7, 8, 9, 1, 5};
    int n = sizeof(arr)/sizeof(arr[0]);
    quickSort(arr, 0, n-1);
    printf("Sorted array: \n");
    printArray(arr, n);
    return 0;
}

Upvotes: 2

Views: 161

Answers (4)

chqrlie

Reputation: 145277

The function does remove the tail call manually. The classic implementation is:

void quickSort(int arr[], int low, int high)
{
    if (low < high) {
        /* pi is partitioning index, arr[pi] is now at right place */
        int pi = partition(arr, low, high);

        // Separately sort elements before partition and after partition
        quickSort(arr, low, pi - 1);
        quickSort(arr, pi + 1, high);
    }
}

As explained in detail by ikegami, the tail call can be removed by manually setting low = pi + 1 and looping to the initial test if (low < high), which can be expressed elegantly by turning the if into a while statement.

This optimisation is probably unnecessary, as modern compilers with optimizations enabled will likely generate the same code from the classic version, since they can eliminate the tail call themselves.

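Note, however, that this approach is not immune to deep recursion in pathological cases, which can lead to stack overflow errors: with the last element as pivot, already sorted input makes every partition maximally unbalanced, and the recursion depth grows linearly with n.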

To keep stack usage under control (a recursion depth of at most about log2(n)), one would only recurse on the smaller half and iterate on the larger half:

void quickSort(int arr[], int low, int high)
{
    if (low < high) {
        /* pi is partitioning index, arr[pi] is now at right place */
        int pi = partition(arr, low, high);
        if (pi - low < high - pi) {
            quickSort(arr, low, pi - 1);
            quickSort(arr, pi + 1, high);  // tail call on the larger half
        } else {
            quickSort(arr, pi + 1, high);
            quickSort(arr, low, pi - 1);  // tail call on the larger half
        }
    }
}

Which gives this after manually removing the tail calls:

void quickSort(int arr[], int low, int high)
{
    while (low < high) {
        /* pi is partitioning index, arr[pi] is now at right place */
        int pi = partition(arr, low, high);
        if (pi - low < high - pi) {
            quickSort(arr, low, pi - 1);
            low = pi + 1;  // iterate on the larger half
        } else {
            quickSort(arr, pi + 1, high);
            high = pi - 1;  // iterate on the larger half
        }
    }
}

Upvotes: 0

rcgldr

Reputation: 28921

To minimize stack usage, the code needs to recurse on the smaller partition and loop on the larger partition, giving O(log(n)) stack space complexity. Worst-case time complexity remains O(n^2). Here is an example with the partition logic inlined in QuickSort. It uses the middle value as the pivot, swapping it into a[hi], so that already sorted or reverse-sorted data is not the worst case.

void QuickSort(int a[], int lo, int hi)
{
    while (lo < hi){
        int t;
        int mid = lo + (hi - lo) / 2;   /* mid point; avoids overflow of lo+hi */
        int p = a[mid];                 /* use mid point for pivot */
        a[mid] = a[hi];                 /* swap with a[hi] */
        a[hi] = p;
        int i = lo;
        for (int j = lo; j < hi; ++j){  /* Lomuto partition */
            if (a[j] < p){
                t = a[i];
                a[i] = a[j];
                a[j] = t;
                ++i;
            }
        }
        t = a[i];
        a[i] = a[hi];
        a[hi] = t;
        if(i - lo <= hi - i){           /* recurse on smaller partition, loop on larger */
            QuickSort(a, lo, i-1);
            lo = i+1;
        } else {
            QuickSort(a, i+1, hi);
            hi = i-1;
        }
    }
}

Upvotes: 0

ikegami

Reputation: 386541

A simple quicksort implementation has two recursive calls.

From Wikipedia:

algorithm quicksort(A, lo, hi) is 
  // Ensure indices are in correct order
  if lo >= hi || lo < 0 then 
    return
    
  // Partition array and get the pivot index
  p := partition(A, lo, hi) 
      
  // Sort the two partitions
  quicksort(A, lo, p - 1) // Left side of pivot
  quicksort(A, p + 1, hi) // Right side of pivot

The second is a tail call and can thus be eliminated. This was done (through manual refactoring) in the code you posted.

The first cannot be eliminated (without introducing a vector/queue/stack).


The refactoring in detail

  1. Conversion from Wikipedia's pseudo-code to C, using signed int indices as in the question (with an unsigned size_t, p - 1 would wrap around to SIZE_MAX when p is 0):

    void quicksort( int *A, int lo, int hi ) {
       if ( lo >= hi )
          return;
    
       int p = partition( A, lo, hi );
       quicksort( A, lo, p - 1 );
       quicksort( A, p + 1, hi );
    }
    
  2. Replace the tail call by updating the parameters and jumping back to the start:

    void quicksort( int *A, int lo, int hi ) {
    quicksort:
       if ( lo >= hi )
          return;
    
       int p = partition( A, lo, hi );
       quicksort( A, lo, p - 1 );
       A = A;
       lo = p + 1;
       hi = hi;
       goto quicksort;
    }
    
  3. Remove useless statements:

    void quicksort( int *A, int lo, int hi ) {
    quicksort:
       if ( lo >= hi )
          return;
    
       int p = partition( A, lo, hi );
       quicksort( A, lo, p - 1 );
       lo = p + 1;
       goto quicksort;
    }
    
  4. Replace goto with a loop:

    void quicksort( int *A, int lo, int hi ) {
       while ( true ) {
          if ( lo >= hi )
             return;
    
          int p = partition( A, lo, hi );
          quicksort( A, lo, p - 1 );
          lo = p + 1;
       }
    }
    
  5. Replace the return with break:

    void quicksort( int *A, int lo, int hi ) {
       while ( true ) {
          if ( lo >= hi )
             break;
    
          int p = partition( A, lo, hi );
          quicksort( A, lo, p - 1 );
          lo = p + 1;
       }
    }
    
  6. Merge the while statement and the if statement:

    void quicksort( int *A, int lo, int hi ) {
       while ( lo < hi ) {
          int p = partition( A, lo, hi );
          quicksort( A, lo, p - 1 );
          lo = p + 1;
       }
    }
    

Upvotes: 0

Lundin

Reputation: 214770

How “pi+1” statement does tail call elimination in quicksort function

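It doesn't eliminate the recursion. Only the second call has been turned into a loop; the first call, quickSort(arr, low, pi - 1), is not a tail call, and all the mainstream x86 compilers keep it as a real recursive call. For example, gcc -O3 (godbolt) generates: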

quickSort:
        cmp     esi, edx
        jge     .L6
        push    r13
        mov     r13, rdi
        push    r12
        mov     r12d, edx
        push    rbp
        mov     ebp, esi
        push    rbx
        sub     rsp, 8
.L3:
        mov     esi, ebp
        mov     edx, r12d
        mov     rdi, r13
        call    partition
        mov     esi, ebp
        mov     rdi, r13
        mov     ebx, eax
        lea     edx, [rax-1]
        call    quickSort     # recursive call
        lea     ebp, [rbx+1]
        cmp     r12d, ebp
        jg      .L3
        add     rsp, 8
        pop     rbx
        pop     rbp
        pop     r12
        pop     r13
        ret

Upvotes: 0
