user60143


Efficient set union and intersection in C++

Given two sets set1 and set2, I need to compute the ratio of the size of their intersection to the size of their union (the Jaccard index). So far, I have the following code:

double ratio(const set<string>& set1, const set<string>& set2)
{
    if( set1.size() == 0 || set2.size() == 0 )
        return 0;

    set<string>::const_iterator iter;
    set<string>::const_iterator iter2;
    set<string> unionset;

    // compute intersection and union
    int len = 0;
    for (iter = set1.begin(); iter != set1.end(); iter++) 
    {
        unionset.insert(*iter);
        if( set2.count(*iter) )
            len++;
    }
    for (iter = set2.begin(); iter != set2.end(); iter++) 
        unionset.insert(*iter);

    return (double)len / (double)unionset.size();   
}

It is very slow (I call the function about 3 million times, each time with different sets). The Python counterpart, on the other hand, is much faster:

def ratio(set1, set2):
    if not set1 or not set2:
        return 0
    return len(set1.intersection(set2)) / len(set1.union(set2))

Any ideas on how to improve the C++ version (preferably without using Boost)?

Answers (2)

Jarod42

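Because std::set keeps its elements sorted, both the intersection size and the union size can be computed in a single merge-style pass, in linear time and without allocating a new set: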

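#include <iterator>  // std::distance
#include <set>
#include <string>

double ratio(const std::set<std::string>& set1, const std::set<std::string>& set2)
{
    if (set1.empty() || set2.empty()) {
        return 0.;
    }
    // Both sets are sorted, so walk them in step, as in the merge step of merge sort.
    std::set<std::string>::const_iterator iter1 = set1.begin();
    std::set<std::string>::const_iterator iter2 = set2.begin();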
    int union_len = 0;
    int intersection_len = 0;
    while (iter1 != set1.end() && iter2 != set2.end()) 
    {
        ++union_len;
        if (*iter1 < *iter2) {
            ++iter1;
        } else if (*iter2 < *iter1) {
            ++iter2;
        } else { // *iter1 == *iter2
            ++intersection_len;
            ++iter1;
            ++iter2;
        }
    }
    union_len += std::distance(iter1, set1.end());
    union_len += std::distance(iter2, set2.end());
    return static_cast<double>(intersection_len) / union_len;
}
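The merge loop above is essentially what the standard algorithm std::set_intersection does. As a sketch of a standard-library-only variant (no Boost), the intersection can be counted by handing std::set_intersection an output iterator that only increments a counter instead of storing elements; the union size then follows from size(set1) + size(set2) - size(intersection). CountingIterator, intersection_size and ratio_via_set_intersection are hypothetical names introduced for this example, not part of the standard library.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <set>
#include <string>

// Hypothetical helper: an output iterator that counts how many values
// are written through it and discards the values themselves.
struct CountingIterator {
    using iterator_category = std::output_iterator_tag;
    using value_type = void;
    using difference_type = std::ptrdiff_t;
    using pointer = void;
    using reference = void;

    std::size_t* count;  // points at the caller's counter

    CountingIterator& operator*() { return *this; }
    CountingIterator& operator++() { return *this; }
    CountingIterator operator++(int) { return *this; }

    template <typename T>
    CountingIterator& operator=(const T&) {
        ++*count;  // one more element found in both sets
        return *this;
    }
};

// Size of the intersection, using the linear merge done by std::set_intersection.
std::size_t intersection_size(const std::set<std::string>& a,
                              const std::set<std::string>& b) {
    std::size_t n = 0;
    std::set_intersection(a.begin(), a.end(), b.begin(), b.end(),
                          CountingIterator{&n});
    return n;
}

// Returns 0 when either set is empty, matching the question's convention.
double ratio_via_set_intersection(const std::set<std::string>& a,
                                  const std::set<std::string>& b) {
    if (a.empty() || b.empty()) {
        return 0.0;
    }
    const std::size_t inter = intersection_size(a, b);
    // size(union) = size(a) + size(b) - size(intersection)
    return static_cast<double>(inter) / (a.size() + b.size() - inter);
}
```

For example, with {"a", "b", "c"} and {"b", "c", "d"} the intersection has 2 elements and the union 4, giving 0.5. Like the hand-written loop this runs in linear time; the only difference is that the merge logic comes from <algorithm>.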

user2357112

You don't actually need to construct the union set. In Python terms, len(s1.union(s2)) == len(s1) + len(s2) - len(s1.intersection(s2)): the size of the union is the sum of the sizes of s1 and s2 minus the number of elements counted twice, which is exactly the number of elements in the intersection. So it is enough to count the intersection, for example inside the question's function after the empty-set check:

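std::size_t len = 0;  // number of elements in the intersection
for (const string &s : set1) {
    len += set2.count(s);  // count() returns 0 or 1 for a std::set
}
return static_cast<double>(len) / (set1.size() + set2.size() - len);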
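One possible refinement of this loop, assuming the two sets can differ a lot in size: each count() call is a logarithmic lookup, so iterating over the smaller set and looking its elements up in the larger one does min(m, n) lookups instead of m. The sketch below (ratio_small_first is a hypothetical name) also keeps the question's empty-set guard, which avoids a 0/0 division when both sets are empty.

```cpp
#include <cstddef>
#include <set>
#include <string>

// Hypothetical variant: loop over the smaller set, look up in the larger one.
// Cost is O(min(m, n) * log(max(m, n))) lookups' worth of work.
double ratio_small_first(const std::set<std::string>& set1,
                         const std::set<std::string>& set2) {
    if (set1.empty() || set2.empty()) {
        return 0.0;  // same convention as the question; also avoids 0/0
    }
    const bool first_smaller = set1.size() <= set2.size();
    const std::set<std::string>& small = first_smaller ? set1 : set2;
    const std::set<std::string>& large = first_smaller ? set2 : set1;

    std::size_t len = 0;  // size of the intersection
    for (const std::string& s : small) {
        len += large.count(s);  // 0 or 1 for a std::set
    }
    // size(union) = size(set1) + size(set2) - size(intersection)
    return static_cast<double>(len) / (set1.size() + set2.size() - len);
}
```

When the sets have similar sizes, the linear merge from the other answer is likely to be at least as fast, so which version wins is worth measuring on the real data.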
