S. N
S. N

Reputation: 3949

What is wrong with my implementation of Kruskal's algorithm using union-find data structure.

I am trying to implement the Kruskal's algorithm, and find the sum of the weights in the MST. I think my problem lays somewhere where I set the parent of each node, but I am not sure, because in small examples it works fine, however with big example it doesnt detect a cycle, and the final answer is wrong. so my find might be wrong, but I am not sure. Here is my code:

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Scanner;

public class Graph {
    private static int parent[];
    private static int numberOfNodes, numberOfEdges, weight;
    private static Node startNode, endNode;
    private static Graph g;
    private static ArrayList<Integer> listOfWeights = new ArrayList<Integer>();
    private static ArrayList<Edge> listOfEdges = new ArrayList<Edge>();
    private static ArrayList<Node> listOfNodes = new ArrayList<Node>();
    private static HashMap<Edge, Integer> distance = new HashMap<Edge, Integer>();
    private static HashMap<Edge, Integer> sortedMap = new HashMap<Edge, Integer>();
    //values = Integer
    private Scanner sc;

    public static void main(String[] args) throws FileNotFoundException {
        // TODO Auto-generated method stub
        g = new Graph();
    }

    public Graph() throws FileNotFoundException {

        sc = new Scanner(new File("data"));
        numberOfNodes = Integer.parseInt(sc.next());
        numberOfEdges = Integer.parseInt(sc.next());

        if(numberOfEdges > 0){
            for(int i = 0; i < numberOfEdges; i++) {
                //read the start point
                startNode = new Node(Integer.parseInt(sc.next()));
                //read the end point
                endNode = new Node(Integer.parseInt(sc.next()));
                //read the price per node
                weight = Integer.parseInt(sc.next());
                Edge e = new Edge(startNode,endNode);
                //set the weight per node
                e.setWeight(weight);
                //put them in a hashmap
                distance.put(e,e.getWeight());
                if(!listOfNodes.contains(startNode)){
                    listOfNodes.add(startNode);
                }
                if(!listOfNodes.contains(endNode)){
                    listOfNodes.add(endNode);
                }
                //System.out.println(distance.get(e));
            }
            System.out.println("without sort distance: " + distance.toString());
            for (Object key : distance.keySet()) {
                listOfEdges.add((Edge) key);
            }
            for (Object value : distance.values()) {
                listOfWeights.add((Integer) value);
            }
            sortedMap = sortHashMapByValuesD(distance);
            //System.out.println("list of nodes: "+ listOfNodes);
            parent = new int[listOfNodes.size()];
            System.out.println("sorted by weights: "+sortedMap.toString());
            System.out.println(kruskalAlgo(sortedMap));
        }
        else{
            System.out.println(0);
        }
    }

    public static void makeSet(int x){
        parent[x-1] = x;
    }

    public static int find(int x){
        //System.out.println("FIND ==> x: " + x + ".parent = " + parent[x-1] );
        if(parent[x-1] == x){
            return x;
        }
        return find(parent[x-1]);
    }

    public static void union(int x, int y){
        //System.out.println("parent[0]: "+parent[0]);
        parent[x-1] = y;
        System.out.println("x: " + x + " UNION parent[x-1]: " + parent[x-1] + " y " + y );
    }

    public static int kruskalAlgo(HashMap<Edge, Integer> s){
        parent[0] = 0;
        for(int i = 0; i < parent.length; i++){
            makeSet(listOfNodes.get(i).getId());
            //System.out.println("parent is: "+parent[i] + " for node"+ listOfNodes.get(i));
        }
        // for each edge (u,v) ∈ G, taken in increasing order by weight
        int min = 0;
        int edgeNumber = 0;
        for (Edge key : s.keySet()) {
            if(edgeNumber == listOfNodes.size()-1){
                //System.out.println("edgeNumber: "+ edgeNumber);
                //System.out.println("listOfNodes.size()-1: "+ (listOfNodes.size()-1));
                return min;
            }
            Node u = key.getFromNode();
            //System.out.println(u);
            Node v = key.getToNode();
            //System.out.println(v);
            if(find(u.getId()) != find (v.getId())){
                min += key.getWeight();
                union(u.getId(),v.getId());
                System.out.println(key + " weight is: " + key.getWeight());
                edgeNumber++;
            }
        }
        return min;
    }

    public static ArrayList<Edge> findSet(Node v){
        ArrayList<Edge> nodes = new ArrayList<Edge>();
        return nodes;
    }

    //make an ordered listed by increasing weights
    public LinkedHashMap<Edge, Integer> sortHashMapByValuesD(HashMap<Edge, Integer> newMap) {
        //list of edges    
        List<Edge> mapKeys = new ArrayList<Edge>(newMap.keySet());
        //list of nodes
        List<Integer> mapValues = new ArrayList<Integer>(newMap.values());
        //sort the nodes
        Collections.sort(mapValues);
        LinkedHashMap<Edge, Integer> sortedMap = new LinkedHashMap<Edge, Integer>();
        Iterator<Integer> valueIt = mapValues.iterator();
        while (valueIt.hasNext()) {
            Object val = valueIt.next();
            Iterator<Edge> keyIt = mapKeys.iterator();
            while (keyIt.hasNext()) {
                Object key = keyIt.next();
                String comp1 = newMap.get(key).toString();
                String comp2 = val.toString();

                if (comp1.equals(comp2)){
                    newMap.remove(key);
                    mapKeys.remove(key);
                    sortedMap.put((Edge)key, (Integer)val);
                    break;
               }
           }
        }
        return sortedMap;
    }
}

Upvotes: 1

Views: 2877

Answers (1)

Peter de Rivaz
Peter de Rivaz

Reputation: 33499

I think the problem may be in your implementation of union:

public static void union(int x, int y){
    parent[x-1] = y;
}

the problem is if x already has been joined into a set, it will already have a parent which you override.

The solution is to join the root of the two candidates instead of the leaf nodes:

public static void union(int x, int y){
    x=find(x);
    y=find(y);
    parent[x-1] = y;
}

By the way,a good description of this Disjoint-set algorithm, plus hints on making it more efficient via "union by rank" and "path compression" is on wikipedia at this page.

Upvotes: 2

Related Questions