I want to solve the following problem: given a vector of n elements, find the number of swaps insertion sort performs to sort it.
Ex:
n = 5
2 1 3 1 2
Answer: 4
Explanation (insertion sort, step by step):
initially: 2 1 3 1 2
1 2 3 1 2 ; 1 swap (1 moves 1 position left)
1 2 3 1 2 ; 0 swaps (3 stays in place)
1 1 2 3 2 ; 2 swaps (1 moves 2 positions left)
1 1 2 2 3 ; 1 swap (2 moves 1 position left)
Total: 1 + 0 + 2 + 1 = 4 swaps.
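The walkthrough above can be checked with a direct simulation. The sketch below is a hypothetical reference helper (`count_swaps_bruteforce` is not from the question): it runs insertion sort on a copy of the input and counts every adjacent swap. Only strictly greater neighbours are swapped, so equal values stay put, as in the example. It is O(n²), so it is only meant for cross-checking faster versions on small inputs, and it returns `long long` because the count can reach n(n-1)/2.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Reference implementation: literally run insertion sort and count swaps.
// The vector is taken by value, so the caller's data is not modified.
long long count_swaps_bruteforce(std::vector<int> v)
{
    long long swaps = 0;
    for (std::size_t i = 1; i < v.size(); ++i) {
        // Move v[i] left while its left neighbour is strictly greater;
        // equal elements are never swapped.
        for (std::size_t j = i; j > 0 && v[j - 1] > v[j]; --j) {
            std::swap(v[j - 1], v[j]);
            ++swaps;
        }
    }
    return swaps;
}
```

For example, `count_swaps_bruteforce({2, 1, 3, 1, 2})` returns 4, matching the steps above. Each adjacent swap removes exactly one inversion (a pair i < j with v[i] > v[j]), which is why the total swap count equals the number of inversions.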
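I store every element together with its position in the initial array, so that I can later remove it from the set by value and position (1st for loop).
Then, going through the array in order, I add to the counter the number of elements in the set that are smaller than the current one, and remove the current element from the set (2nd for loop). At that point the set only holds elements that come after the current one, so the sum over all elements is the number of inversions, which equals the number of swaps.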
As you can see, the bottleneck is std::distance, which is linear because std::set only provides bidirectional iterators. How can I get O(1) complexity for this step without implementing my own tree?
#include <iterator>
#include <set>
#include <utility>
#include <vector>
using namespace std;

int count_operations(vector<int> &v)
{
    // (value, original index) pairs, so equal values stay distinct keys
    set<pair<int, int>> s;
    // O(N * logN)
    for (int i = 0; i < (int) v.size(); ++i)
    {
        s.insert(make_pair(v[i], i));
    }
    int cnt = 0;
    // desired: O(N * log N); current: O(N^2)
    for (int i = 0; i < (int) v.size(); ++i)
    {
        auto item = make_pair(v[i], i);
        auto it = s.find(item);
        int dist = distance(s.begin(), it); // O(N); I want O(1)
        s.erase(it);
        cnt += dist;
    }
    return cnt;
}
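The code comment above asks for O(N log N) overall, which only needs O(log N) per rank query. One way to get that without a hand-written balanced tree is a Fenwick (binary indexed) tree over compressed values. This is a different technique from the std::set approach, shown as a hedged sketch; `count_swaps_fenwick` is a hypothetical name. It scans right to left, so the tree holds exactly the elements after position i, the same elements the second loop above counts over.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Counts inversions (= insertion sort swaps) with a Fenwick tree. O(n log n).
long long count_swaps_fenwick(const std::vector<int> &v)
{
    // Coordinate compression: each value maps to its 1-based rank among
    // the distinct values, so the tree never needs more than n slots.
    std::vector<int> sorted(v);
    std::sort(sorted.begin(), sorted.end());
    sorted.erase(std::unique(sorted.begin(), sorted.end()), sorted.end());

    std::vector<long long> tree(sorted.size() + 1, 0);  // 1-based Fenwick array

    // pos & (~pos + 1) isolates the lowest set bit of pos.
    auto add = [&](std::size_t pos) {        // record one element of rank pos
        for (; pos < tree.size(); pos += pos & (~pos + 1))
            ++tree[pos];
    };
    auto prefix = [&](std::size_t pos) {     // recorded elements with rank <= pos
        long long sum = 0;
        for (; pos > 0; pos -= pos & (~pos + 1))
            sum += tree[pos];
        return sum;
    };

    long long cnt = 0;
    // Right to left: the tree holds v[i+1..n-1], so prefix(rank - 1)
    // counts the later elements strictly smaller than v[i].
    for (std::size_t i = v.size(); i-- > 0;) {
        std::size_t rank = static_cast<std::size_t>(
            std::lower_bound(sorted.begin(), sorted.end(), v[i]) - sorted.begin()) + 1;
        cnt += prefix(rank - 1);
        add(rank);
    }
    return cnt;
}
```

Equal values get the same rank, and only strictly smaller ranks are queried, so ties are not counted, which matches insertion sort leaving equal elements in place.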
The problem is getting the rank of each element in a set, which an order statistic tree answers in O(log N) per query. GCC's libstdc++ ships one in its policy-based data structures extension (pb_ds), which can be used as follows.
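The two operations the order statistics policy adds can be seen in isolation. This sketch assumes GCC with libstdc++; `ordered_pairs` and `build` are illustrative names, not part of the library. `order_of_key(k)` returns the number of keys strictly smaller than `k`, and `find_by_order(i)` returns an iterator to the i-th smallest key (0-based).

```cpp
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <functional>
#include <utility>
#include <vector>

// Order statistic tree keyed by (value, original index), as in the answer.
using ordered_pairs = __gnu_pbds::tree<
    std::pair<int, int>,                             // key type
    __gnu_pbds::null_type,                           // no mapped value: set-like
    std::less<std::pair<int, int>>,                  // key ordering
    __gnu_pbds::rb_tree_tag,                         // red-black tree
    __gnu_pbds::tree_order_statistics_node_update>;  // maintains subtree sizes

// Hypothetical helper: insert every (v[i], i) pair into a fresh tree.
ordered_pairs build(const std::vector<int> &v)
{
    ordered_pairs s;
    for (int i = 0; i < static_cast<int>(v.size()); ++i)
        s.insert(std::make_pair(v[i], i));
    return s;
}
```

For the input 2 1 3 1 2 the keys in order are (1,1), (1,3), (2,0), (2,4), (3,2), so `order_of_key(std::make_pair(2, 0))` on the built tree is 2 and `*find_by_order(0)` is (1,1).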
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <functional>
#include <utility>
#include <vector>
using namespace std;
using namespace __gnu_pbds;
typedef tree<
    pair<int, int>,                     /* key type */
    null_type,                          /* mapped type (null_mapped_type in older GCC releases) */
    less< pair<int, int> >,             /* comparison */
    rb_tree_tag,                        /* for having an rb tree */
    tree_order_statistics_node_update> order_set;
long long count_ops(std::vector<int> &v)
{
    order_set s;
    long long cnt = 0;  /* up to n*(n-1)/2 swaps, which overflows int for large n */
    /* O(N*log(N)) */
    for (int i = 0; i < (int) v.size(); i++)
        s.insert(pair<int, int>(v[i], i));
    for (int i = 0; i < (int) v.size(); i++)
    {
        /* finding rank is O(log(N)), so overall complexity is O(N*log(N)) */
        cnt += s.order_of_key(pair<int, int>(v[i], i));
        s.erase(pair<int, int>(v[i], i));
    }
    return cnt;
}
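pb_ds is an extension shipped with GCC's libstdc++, so the code above won't build with standard libraries that lack it (MSVC's or LLVM's libc++, for example). If portability matters, counting inversions during a merge sort gives the same number in O(N log N) using only standard headers. This is a swapped-in technique, not part of the answer above, and `count_swaps_mergesort` / `sort_count` are hypothetical names.

```cpp
#include <cstddef>
#include <vector>

// Sorts v[lo, hi) and returns the number of inversions inside that range.
// buf is scratch space with the same size as v.
static long long sort_count(std::vector<int> &v, std::vector<int> &buf,
                            std::size_t lo, std::size_t hi)
{
    if (hi - lo < 2)
        return 0;
    std::size_t mid = lo + (hi - lo) / 2;
    long long cnt = sort_count(v, buf, lo, mid) + sort_count(v, buf, mid, hi);

    std::size_t i = lo, j = mid, k = lo;
    while (i < mid && j < hi) {
        if (v[j] < v[i]) {
            // v[j] is smaller than every element still left in v[i..mid),
            // so it forms an inversion with each of them.
            cnt += static_cast<long long>(mid - i);
            buf[k++] = v[j++];
        } else {
            buf[k++] = v[i++];  // ties come from the left half: not counted
        }
    }
    while (i < mid) buf[k++] = v[i++];
    while (j < hi)  buf[k++] = v[j++];
    for (std::size_t t = lo; t < hi; ++t)
        v[t] = buf[t];
    return cnt;
}

// Number of insertion sort swaps; works on a copy of the input.
long long count_swaps_mergesort(std::vector<int> v)
{
    std::vector<int> buf(v.size());
    return sort_count(v, buf, 0, v.size());
}
```

Taking equal elements from the left half first keeps ties out of the count, the same convention as the (value, index) keys in the tree-based version.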