Bluefarmer

Reputation: 5

How to add a comparator to a custom sort function

So I've implemented merge sort, which for all intents and purposes could be any custom sort function, and I've started turning it into a template function.

Where I've run into a problem is adding the possibility of passing a custom compare function in order to sort in different ways (e.g. std::greater, std::less, or any custom one).

I've verified that the sorting algorithm works when I replace the ints with T. How would I add the custom compare function from here, so that I can also sort custom objects, etc.?

#include <iostream>
#include <vector>
using namespace std;

template <  typename T, 
            class Compare>
void merge( vector<T> &arr, int start, int mid, int end, Compare comp ) 
{
    int lptr = start; 
    int rptr = mid+1; 
    int tempptr = 0; 

    vector<T> temp( end - start + 1 ); 

    for ( int i = 0; i<temp.size(); i++)
    {
        if ( lptr > mid ) //done with left-section, just move the right elements
        {   
            temp[tempptr] = arr[rptr];
            rptr++;
        } else if ( rptr > end ) //done with right-section, just move the left elements
        {
            temp[tempptr] = arr[lptr];
            lptr++; 
        } else if ( comp( arr[rptr], arr[lptr] )) // right item < left item, move right item
        {
            temp[tempptr] = arr[rptr]; 
            rptr++; 
        } else          //otherwise right item is not < left item, move left item (equal items keep their order)
        {
            temp[tempptr] = arr[lptr];
            lptr++; 
        }
        tempptr++;
    }

    for ( int i = 0; i<temp.size(); i++)
    {
        arr[start + i] = temp[i]; 
    }
}






template <  typename T, 
            class Compare>
void mergeSort( vector<T> &arr, int start, int end, Compare comp)
{   

    //if we're down to single elements, do nothing
    if ( start < end ){
        //call to right and left 'child' 
        int mid = (start + end) / 2; 

        mergeSort( arr, start, mid ); 
        mergeSort( arr, mid + 1, end );

        //call to merge
        merge( arr, start, mid, end ); 
    }
}



int main()
{   
    vector<float> arr = {7,8, 2, 6.6, 1, 4.1, 5, 3, 8, 9};
    cout << "before sorting:" << endl;
    for ( auto n : arr ) 
        cout << n << ", ";
    cout << endl;
    mergeSort( arr, 0, arr.size() - 1); 

    cout << "after sorting:" << endl;
    for ( auto n : arr ) 
        cout << n << ", ";
    cout << endl;

    return 0; 
}

Thanks in advance.

Upvotes: 0

Views: 2276

Answers (3)

Francis Cugler

Reputation: 7905

I think you are overcomplicating your implementations of merge and merge sort. I've written the merge function as two overloads and placed them in a namespace so that they don't conflict with the standard library's std::merge. Take a look at my example to see what has been done.

#include <iostream>
#include <vector>
#include <algorithm>
#include <iterator>   // std::ostream_iterator, std::back_inserter

namespace test {

// without comp predicate
template<class InputIt1, class InputIt2, class OutputIt>
OutputIt merge( InputIt1 first1, InputIt1 last1,
                InputIt2 first2, InputIt2 last2,
                OutputIt d_first ) {

    for( ; first1 != last1; ++d_first ) {
        if( first2 == last2 ) {
            return std::copy( first1, last1, d_first );
        }
        if( *first2 < *first1 ) {
            *d_first = *first2;
            ++first2;
        } else {
            *d_first = *first1;
            ++first1;
        }
    }
    return std::copy( first2, last2, d_first );
}

// with comp predicate
template<class InputIt1, class InputIt2, class OutputIt, class Compare>
OutputIt merge( InputIt1 first1, InputIt1 last1,
                InputIt2 first2, InputIt2 last2,
                OutputIt d_first, Compare comp ) {

    for( ; first1 != last1; ++d_first ) {
        if( first2 == last2 ) {
            return std::copy( first1, last1, d_first );
        }
        // This is where I replaced the default `< operator` with the `Compare` predicate.
        if( comp( *first2, *first1 ) ) {
            *d_first = *first2;
            ++first2;
        } else {
            *d_first = *first1;
            ++first1;
        }
    }
    return std::copy( first2, last2, d_first );
}

} // namespace test

int main() {
    std::vector<int> v1{ 1,3,5,7,9 };
    std::vector<int> v2{ 2,4,6,8 };
    std::vector<int> v3;

    // print this way
    std::cout << "v1 : ";
    for( auto& v : v1 ) {
        std::cout << v << ' ';
    }
    std::cout << '\n';

    // or print this way
    std::cout << "v2 : ";
    std::copy( v2.begin(), v2.end(), std::ostream_iterator<int>( std::cout, " " ) );
    std::cout << '\n';

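    // Prints v3 with a label, then clears it so each merge starts empty.
    auto show = [&v3]( const char* label ) {
        std::cout << label;
        std::copy( v3.begin(), v3.end(), std::ostream_iterator<int>( std::cout, " " ) );
        std::cout << '\n';
        v3.clear();
    };

    // Merge without binary predicate comp function - functor etc. (uses operator<)
    test::merge( v1.begin(), v1.end(), 
                 v2.begin(), v2.end(), 
                 std::back_inserter( v3 ) );
    show( "default     : " );

    // using std's functors std::less - std::greater
    test::merge( v1.begin(), v1.end(), 
                 v2.begin(), v2.end(), 
                 std::back_inserter( v3 ), 
                 std::less<int>() );
    show( "std::less   : " );

    // The inputs must already be ordered by the same predicate,
    // so walk the ascending vectors in reverse for std::greater.
    test::merge( v1.rbegin(), v1.rend(), 
                 v2.rbegin(), v2.rend(), 
                 std::back_inserter( v3 ), 
                 std::greater<int>() );
    show( "std::greater: " );

    // using lambdas as predicate compare objects (generic lambdas need C++14).
    test::merge( v1.begin(), v1.end(), 
                 v2.begin(), v2.end(), 
                 std::back_inserter( v3 ), 
                 []( auto&& a, auto&& b ) { return a < b; } );
    show( "lambda <    : " );

    test::merge( v1.rbegin(), v1.rend(), 
                 v2.rbegin(), v2.rend(), 
                 std::back_inserter( v3 ), 
                 []( auto&& a, auto&& b ) { return a > b; } );
    show( "lambda >    : " );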

    std::cout << "\nPress Enter to quit.\n";
    std::cin.get();
    return 0;
}

These two overloads do the merge step you are looking for. They let you choose which comparison is used (predicate function, functor, lambda, etc.) while keeping a single function name.

Using iterator template parameters (InputIt, OutputIt) simplifies a lot of the internals. There is no need to keep track of sizes or index positions, or to index into arrays.

All we really need to do is loop over the input iterators:

- If the second range is exhausted, copy the rest of the first range with std::copy(...) and return.
- Otherwise, compare the two current elements, assign the one that comes first (from first2 or first1) to the output, and advance that iterator.
- When the loop ends because the first range is exhausted, copy whatever remains of the second range and return the result of std::copy(...).

The first overload compares with operator< by default, while the second takes the predicate.

This also lets you pass any container that provides begin and end iterators, which makes the functions generic and reusable, in line with modern C++ best practices.

Upvotes: 0

Y.S

Reputation: 311

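As Sam Varshavchik stated, replace your comparison operator with your compare function. The function belongs on the comparison between elements. It does not belong on the index checks such as lptr > mid, which only test whether a section is used up. Meaning this:

        } else if ( arr[rptr] < arr[lptr] ) // right item < left item, move right item

Changes to this:

        } else if ( comp( arr[rptr], arr[lptr] ) ) // right item < left item, move right item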

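By the way, mergeSort never passes comp on to its recursive calls or to merge. Those calls have too few arguments and won't compile, so forward comp explicitly. Since start >= end is the base case, no extra branch is needed:

template <  typename T, 
            class Compare>
void mergeSort( vector<T> &arr, int start, int end, Compare comp)
{   

    //if we're down to single elements, do nothing
    if ( start < end ){
        //call to right and left 'child', passing comp along
        int mid = (start + end) / 2; 

        mergeSort( arr, start, mid, comp ); 
        mergeSort( arr, mid + 1, end, comp );

        //call to merge, passing comp along
        merge( arr, start, mid, end, comp ); 
    }
}

The top-level call then needs a comparator too, e.g. mergeSort( arr, 0, arr.size() - 1, std::less<float>() );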

Upvotes: 1

Raxvan

Reputation: 6505

Suppose you have a class or struct CustomType.

Pre-C++11, using a functor:

struct CustomCompare
{
    bool operator ()(const CustomType& a, const CustomType& b) const
    {
        return a.Whatever < b.Whatever;
    }
};

//usage
merge(vector<CustomType> ..., CustomCompare());

Since C++11, using a lambda:

auto CustomCompare = [](const CustomType & a,const CustomType& b)
{
    return a. .... ;
};
//usage
merge(vector<CustomType> ..., CustomCompare);

There is a third option:

You can use std::less, but then there must be an operator< that accepts your CustomType as arguments.

Example:

struct CustomType
{
    //...
    bool operator < (const CustomType& other)const
    {
        return this->Whatever < other.Whatever;
    }
};

Alternatively, you can specialize std::less for your type:

namespace std
{
    template <>
    struct less <CustomType>
    {
        bool operator()(const CustomType & a, const CustomType & b) const
        {
            return ...
        }
    };
}

Upvotes: 1
