Dan Lincan

std::set: get the position of an element

I want to solve the following problem: given a vector of n elements, find the number of swaps the insertion sort algorithm needs to sort it.

Example:
n = 5
2 1 3 1 2
Answer: 4

Explanation (insertion sort, step by step):
initially: 2 1 3 1 2
1 2 3 1 2 ; 1 swap (1 moves 1 position left)
1 2 3 1 2 ; 0 swaps
1 1 2 3 2 ; 2 swaps (1 moves 2 positions left)
1 1 2 2 3 ; 1 swap (2 moves 1 position left)
Total: 1 + 0 + 2 + 1 = 4 swaps.
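To double-check the expected answer, a brute-force sketch can simply run insertion sort on a copy of the input and count every adjacent swap. The function name `insertion_sort_swaps` is hypothetical. It uses a strict `>` comparison, so equal elements are never swapped, which matches the walkthrough where the second 1 stops next to the first one.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Brute-force reference: insertion sort on a copy of v, counting swaps.
// O(N^2) in the worst case, so only useful for checking small inputs.
long long insertion_sort_swaps(std::vector<int> v)  // by value: sorts a copy
{
    long long swaps = 0;
    for (std::size_t i = 1; i < v.size(); ++i)
    {
        // Move v[i] left while its left neighbour is strictly greater.
        // Equal elements are not swapped (strict >), keeping the sort stable.
        for (std::size_t j = i; j > 0 && v[j - 1] > v[j]; --j)
        {
            std::swap(v[j - 1], v[j]);
            ++swaps;
        }
    }
    return swaps;
}
```

For the input 2 1 3 1 2, `insertion_sort_swaps({2, 1, 3, 1, 2})` returns 4, the same total as the step-by-step trace.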

My solution

I store each item together with its position in the initial array, so that later I can remove it from the set by value and position (1st for loop).
Then, going through the elements in their original order, I count how many elements still in the set are smaller than the current one, add that count to the counter, and remove the current element from the set (2nd for loop).
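Why counting smaller elements gives the swap count: each insertion sort swap exchanges two adjacent elements with $v_{j-1} > v_j$, which removes exactly one inversion, and the sort stops when no inversions are left. So the number of swaps equals the number of inversions. When the pair $(v_i, i)$ is queried, every pair with a smaller index has already been erased, so the smaller pairs still in the set are exactly the later elements with a strictly smaller value:

$$\text{swaps}(v) = \left|\{(i, j) : i < j,\ v_i > v_j\}\right| = \sum_{i=0}^{n-1} \left|\{\, j > i : v_j < v_i \,\}\right|$$

For the example $v = (2, 1, 3, 1, 2)$ the per-element terms are $2 + 0 + 2 + 0 + 0 = 4$.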

As you can see, the problem is std::distance, which has linear complexity because std::set only provides bidirectional iterators. How can I get O(1) complexity without having to implement my own tree?

#include <iterator>
#include <set>
#include <utility>
#include <vector>
using namespace std;

int count_operations(vector<int> &v)
{
    set<pair<int, int>> s;
    // O(N * logN)
    for(int i = 0; i < (int) v.size(); ++i)
    {
        s.insert(make_pair(v[i], i));
    }
    int cnt = 0;
    // desired: O(N * log N) ; current O(N^2)
    for(int i = 0; i < (int) v.size(); ++i)
    {
        auto item = make_pair(v[i], i);
        auto it = s.find(item);
        int dist = distance(s.begin(), it);//O(N); I want O(1)
        s.erase(it);
        cnt += dist;
    }
    return cnt;
}
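For comparison, O(N log N) is also reachable with only standard containers by swapping in a different technique: a Fenwick tree (binary indexed tree) over coordinate-compressed values, scanned from right to left. This is a sketch, not the question's approach and not an order statistic tree; it is a small array-based structure rather than a balanced tree. The function name `count_inversions_bit` is hypothetical. It returns `long long` because the count can reach n(n-1)/2, which overflows `int` once n exceeds roughly 65,000.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical alternative: count inversions (pairs i < j with v[i] > v[j])
// with a Fenwick tree over value ranks. O(N log N) time, O(N) extra memory.
long long count_inversions_bit(const std::vector<int>& v)
{
    // Coordinate compression: sorted list of the distinct values.
    std::vector<int> vals(v);
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    const int m = static_cast<int>(vals.size());

    // 1-based Fenwick array holding counts of ranks inserted so far.
    std::vector<int> bit(m + 1, 0);
    long long inversions = 0;

    // Scan right to left, so the tree holds exactly the elements after k.
    for (int k = static_cast<int>(v.size()) - 1; k >= 0; --k)
    {
        // 1-based rank of v[k] among the distinct values.
        const int r = static_cast<int>(
            std::lower_bound(vals.begin(), vals.end(), v[k]) - vals.begin()) + 1;

        // Prefix sum over ranks 1..r-1: later elements strictly smaller than v[k].
        for (int i = r - 1; i > 0; i -= i & -i)
            inversions += bit[i];

        // Record v[k] at rank r.
        for (int i = r; i <= m; i += i & -i)
            ++bit[i];
    }
    return inversions;
}
```

On the example input 2 1 3 1 2 it returns 4, the same value as the insertion sort count.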

Answers (1)

Subhasis Das

The problem is getting the rank of each element in a set. This can be done in O(log N) per query with an order statistic tree, using the policy-based data structures (pb_ds) extension that ships with GNU libstdc++, as follows.

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <vector>
#include <utility>
using namespace std;
using namespace __gnu_pbds;

typedef tree<
    pair<int, int>, /* key type */
    null_type, /* mapped type: null_type makes it a set (GCC before 4.7 called it null_mapped_type) */
    less< pair<int, int> >, /* comparison */
    rb_tree_tag, /* for having an rb tree */
    tree_order_statistics_node_update> order_set;

int count_ops(std::vector<int> &v)
{
    order_set s;
    int cnt = 0;
    /* O(N*log(N)) */
    for(int i = 0; i < v.size(); i++)
        s.insert(pair<int, int>(v[i], i));
    for(int i = 0; i < v.size(); i++)
    {
        /* finding rank is O(log(N)), so overall complexity is O(N*log(N)) */
        cnt += s.order_of_key(pair<int, int>(v[i], i));
        s.erase(pair<int, int>(v[i], i));
    }
    return cnt;
}
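A small sketch of the two rank queries that `tree_order_statistics_node_update` provides, assuming GCC with libstdc++ (pb_ds is a GNU extension and is not available with other standard libraries). `order_of_key(k)` returns how many keys are strictly smaller than `k`, and `find_by_order(r)` returns an iterator to the key of 0-based rank `r`. The helper names `build`, `ranks` and `kth_smallest` are hypothetical.

```cpp
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// GNU-only: pb_ds ships with libstdc++ as an extension, not standard C++.
using Key = std::pair<int, int>;  // (value, original index): keys stay unique
using ordered_set = __gnu_pbds::tree<
    Key,
    __gnu_pbds::null_type,  // no mapped value, so it behaves like a set
    std::less<Key>,
    __gnu_pbds::rb_tree_tag,
    __gnu_pbds::tree_order_statistics_node_update>;  // enables rank queries

// Insert every (v[i], i) pair; O(N log N) in total.
ordered_set build(const std::vector<int>& v)
{
    ordered_set s;
    for (int i = 0; i < static_cast<int>(v.size()); ++i)
        s.insert(Key(v[i], i));
    return s;
}

// order_of_key: number of keys strictly smaller than (v[i], i); O(log N) each.
std::vector<std::size_t> ranks(const std::vector<int>& v)
{
    ordered_set s = build(v);
    std::vector<std::size_t> out;
    out.reserve(v.size());
    for (int i = 0; i < static_cast<int>(v.size()); ++i)
        out.push_back(s.order_of_key(Key(v[i], i)));
    return out;
}

// find_by_order: iterator to the key with 0-based rank k; requires k < v.size().
int kth_smallest(const std::vector<int>& v, std::size_t k)
{
    ordered_set s = build(v);
    return s.find_by_order(k)->first;
}
```

For 2 1 3 1 2, `ranks` gives 2 0 4 1 3, i.e. each element's position after a stable sort, because nothing is erased. The answer's loop erases each key right after querying it, so it only counts later, smaller elements.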
