Raghav

Reputation: 169

Disjoint Set implementation in C++

I came across this problem in an online contest and I'm trying to solve it using a disjoint-set data structure.

Problem Definition:

Bob visits a nuclear power plant during his school excursion. He observes that there are n nuclear rods in the plant and that the initial efficiency of each rod is 1. Over time, the rods start fusing with each other to form groups, and this process reduces the efficiency of a group to the square root of its size. Bob, being a curious student, wants to know the total efficiency of the nuclear plant after some time, which is obtained by adding the efficiencies of all the groups.

Initially each rod belongs to its own group of size 1. There are f fusions. If rod1 and rod2 are fused, it means their whole groups are merged.

Sample Input:

5 2
1 2
2 3

Sample Output:

3.73

Explanation:

n=5 fusions=2

group 1,2,3 => 1.73 (sqrt(3))

group 4 => 1

group 5 => 1

total = (1.73 + 1 + 1) = 3.73

My code:

#include <iostream>
#include <set>
#include <vector>
#include <cstdio>
#include <cmath>
#include <iomanip>
using namespace std;

typedef long long int lli;

vector<lli> p,rank1,setSize;   /* p keeps track of the parent
                                * rank1 keeps track of the rank
                                * setSize keeps track of the size of the set. 
                                */ 

lli findSet(lli i) { return (p[i] == i) ? i : (p[i] = findSet(p[i])); } 

bool sameSet(lli x,lli y) { return findSet(x) == findSet(y); }


void union1(lli x,lli y) {      // union merges two sets.

    if(!sameSet(x,y)) {

        lli i = findSet(x), j = findSet(y);

        if(rank1[i] > rank1[j]) {
            p[j] = i;
            setSize[i] += setSize[j];           

        }

        else {
            p[i] = j;
            setSize[j] += setSize[i];
            if(rank1[i] == rank1[j])
                rank1[j]++;
        }
    }
}

int main() {

    freopen("input","r",stdin);

    lli n;
    cin >> n;                               //number of nuclear rods

    setSize.assign(n,1);                    //Initialize the setSize with 1 because every element is in its own set
    p.assign(n,0);          
    rank1.assign(n,0);                      //Initialize ranks with 0's.

    for(lli i = 0; i < n; i++) p[i] = i;    //Every set is distinct. Thus it is its own parent.

    lli f;
    cin >> f;                               //Number of fusions.

    while(f--){                 

        lli x,y;
        cin >> x >> y;                      //combine two rods
        union1(x,y);                        

    }   

    double ans = 0.0;                       //must start at zero: an uninitialized local double has an indeterminate value.

    set<lli> s (p.begin(),p.end());         //Get the representative of all the sets.

    for(lli i : s){     
        ans += sqrt(setSize[i]);            //add the square root of this set's size.

    }

    printf("\n%.2f", ans);                  //display the answer in 2 decimal places.
}

The code above seems to pass every test case except one.

The input on which my code fails is here.

Expected output: 67484.82

My output: 67912.32

I can't work out where I went wrong, since the input is far too large to check by hand.

Any help would really be appreciated. Thanks in advance.

Upvotes: 3

Views: 1314

Answers (1)

Raziman T V

Reputation: 489

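p holds the immediate parent of each element, not its findSet value (the root of its set). Path compression only rewrites the nodes that findSet actually visits, so after the last union some p[k] can still point to a node that used to be a root but has since been attached under another one. Hence when you do set<lli> s (p.begin(),p.end()); the set can contain such non-root nodes in addition to the real representatives, and their stale setSize values are added to the sum, which inflates the total.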

There are two ways I can think of to deal with this:

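  1. Insert findSet(i) for every element i into the set with a loop, instead of copying p directly.
  2. After you do setSize[i] += setSize[j], set setSize[j] = 0 (and likewise set setSize[i] = 0 after setSize[j] += setSize[i] in the other branch). This way, the intermediate parents contribute sqrt(0) = 0 to the sum.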

Upvotes: 1
