Karup

Reputation: 2079

Given a bipartite graph with node weights, get an ordered list of one type of node based on a certain heuristic

Given a bipartite graph with node weights for both type A and type B nodes, as shown below:

[Figure: bipartite graph with node weights on the type A and type B nodes]

I want to output an ordered list of type B nodes as defined by the following heuristic:

  1. For each type B node, sum the weights of the type A nodes it has an edge to, and multiply that sum by its own node weight to get the node value.
  2. Select the type B node with the highest value and append it to the output set S.
  3. Delete the selected node from type B, along with all the type A nodes it has an edge to.
  4. Go back to step 1 as long as any remaining type B node still has an edge to a type A node.
  5. Append the remaining type B nodes to the output set in order of their node weight.
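For reference, a direct (naive) C++ walk-through of steps 1 to 5 might look like the sketch below; it is handy as a correctness baseline for faster versions. It is not the asker's code: the function name naiveOrder, the adjacency-list input (adjB[y] lists the A indices linked to B node y), the lower-index tie-break in step 2 and the decreasing-weight order in step 5 are all assumptions. Every round rescans all edges, so it costs O(nB * (nB + m)) overall.

```cpp
#include <algorithm>
#include <vector>

// Naive walk-through of the heuristic (hypothetical helper, not from the question).
// Assumptions: non-negative integer weights, step-2 ties go to the lower index,
// step 5 orders leftover B nodes by decreasing weight (ties: lower index).
std::vector<int> naiveOrder(const std::vector<long long>& wA,
                            const std::vector<long long>& wB,
                            const std::vector<std::vector<int>>& adjB) {
    const int nB = static_cast<int>(wB.size());
    std::vector<bool> delA(wA.size(), false), delB(nB, false);
    std::vector<int> out;

    while (true) {
        int best = -1;
        long long bestVal = -1;
        for (int y = 0; y < nB; ++y) {
            if (delB[y]) continue;
            long long sum = 0;
            bool hasEdge = false;
            for (int x : adjB[y]) {
                if (delA[x]) continue;
                sum += wA[x];
                hasEdge = true;
            }
            if (!hasEdge) continue;            // step 4: only B nodes still linked to A
            const long long val = sum * wB[y]; // step 1
            if (val > bestVal) { bestVal = val; best = y; }
        }
        if (best < 0) break;                   // no B node has an edge left
        out.push_back(best);                   // step 2
        delB[best] = true;                     // step 3: delete B node...
        for (int x : adjB[best]) delA[x] = true; // ...and its A neighbours
    }

    // Step 5: leftover B nodes by decreasing weight (stable keeps index order on ties).
    std::vector<int> rest;
    for (int y = 0; y < nB; ++y)
        if (!delB[y]) rest.push_back(y);
    std::stable_sort(rest.begin(), rest.end(),
                     [&](int a, int b) { return wB[a] > wB[b]; });
    out.insert(out.end(), rest.begin(), rest.end());
    return out;
}
```

For example, with A weights {5, 3, 2}, B weights {2, 4, 1, 7} and edges B0-A0, B1-A0, B1-A1, B2-A2, the result is {1, 2, 3, 0}: B1 wins with (5 + 3) * 4 = 32, then B2 with 2, then the unlinked B3 and B0 by decreasing weight.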

The figure below shows an example:

[Figure: worked example of the heuristic]

For this example, the output set will be: (Y, Z, X)

The naive approach is to simply walk through this algorithm, but assuming the bipartite graph is huge, I am looking for the most efficient way to find the output set. Note that I only need the ordered list of type B nodes as output, not the intermediate calculated values (e.g. 50, 15, 2).

Upvotes: 2

Views: 340

Answers (2)

Mivik

Reputation: 211

I provided a solution in C++ that is essentially @ravenspoint's idea: it maintains a heap and takes the B node with the highest value each time. I used priority_queue instead of set because the former is much faster than the latter here.
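For comparison, the std::set variant mentioned above could look like the sketch below. Every live B node keeps exactly one entry, and a value change is done as erase, update, re-insert, so no stale entries accumulate. The name setOrder, the key layout and the tie rules are assumptions for illustration, and it assumes strictly positive weights so that unlinked B nodes (value 0) sort after linked ones and come out by decreasing weight. std::priority_queue usually wins in practice because it is a binary heap over a contiguous std::vector, while std::set allocates one tree node per entry.

```cpp
#include <iterator>
#include <set>
#include <tuple>
#include <vector>

// Hypothetical std::set version of the heuristic ("decrease-key" via erase + insert).
// Assumes strictly positive weights. Key = (value, B weight, -index): the largest
// key wins, so equal values go to the heavier B node, then to the lower index.
std::vector<int> setOrder(const std::vector<long long>& wA,
                          const std::vector<long long>& wB,
                          const std::vector<std::vector<int>>& adjA,
                          const std::vector<std::vector<int>>& adjB) {
    using Key = std::tuple<long long, long long, int>;
    const int nA = static_cast<int>(wA.size());
    const int nB = static_cast<int>(wB.size());

    std::vector<long long> sum(nB, 0);          // current sum of live A weights
    for (int x = 0; x < nA; ++x)
        for (int y : adjA[x]) sum[y] += wA[x];

    auto key = [&](int y) { return Key{sum[y] * wB[y], wB[y], -y}; };
    std::set<Key> live;                          // exactly one entry per live B node
    for (int y = 0; y < nB; ++y) live.insert(key(y));

    std::vector<bool> delA(nA, false);
    std::vector<int> out;
    while (!live.empty()) {
        const int y = -std::get<2>(*live.rbegin());  // current maximum
        live.erase(std::prev(live.end()));
        out.push_back(y);
        for (int x : adjB[y]) {
            if (delA[x]) continue;
            delA[x] = true;                      // delete A neighbour x
            for (int ny : adjA[x]) {
                auto it = live.find(key(ny));
                if (it == live.end()) continue;  // ny was already output
                live.erase(it);                  // erase old key,
                sum[ny] -= wA[x];                // update the value,
                live.insert(key(ny));            // re-insert with the new key
            }
        }
    }
    return out;
}
```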


#include <chrono>
#include <iostream>
#include <queue>
#include <vector>

int nA, nB;
std::vector<int> A, B;
// long long: sum[y] * B[y] can overflow int for large weights or degrees
std::vector<long long> sum;
std::vector<std::vector<int>> adjA, adjB;
inline std::vector<int> solve() {
    struct Node {
        // We store the value of the node `x` WHEN IT IS INSERTED.
        // Later changes to sum[x] do not affect this Node.
        int x;
        long long val;

        Node(int x): x(x), val(sum[x] * B[x]) {}

        // Max-heap order: higher value first; ties go to the larger B weight
        bool operator<(const Node &t) const { return val == t.val? (B[x] < B[t.x]): (val < t.val); }
    };

    std::priority_queue<Node> q;
    std::vector<bool> delA(nA, false), delB(nB, false);
    std::vector<int> ret; ret.reserve(nB);

    for (int x = 0; x < nA; ++x)
        for (int y : adjA[x]) sum[y] += A[x];
    for (int y = 0; y < nB; ++y) q.emplace(y);
    while (!q.empty()) {
        const Node node = q.top(); q.pop();
        const int y = node.x;
        if (sum[y] * B[y] != node.val || delB[y]) // This means this Node is obsolete
            continue;
        delB[y] = true;
        ret.push_back(y);
        for (int x : adjB[y]) {
            if (delA[x]) continue;
            delA[x] = true;
            for (int ny : adjA[x]) {
                if (delB[ny]) continue;
                sum[ny] -= A[x];
                // Each edge reaches this point at most once (x was just marked deleted), so at most `m` pushes in total
                q.emplace(ny);
            }
        }
    }

    return ret;
}
int main() {
    std::cout << "Number of nodes in type A: "; std::cin >> nA;
    A.resize(nA); adjA.resize(nA);
    std::cout << "Weights of nodes in type A: ";
    for (int &v : A) std::cin >> v;

    std::cout << "Number of nodes in type B: "; std::cin >> nB;
    B.resize(nB); adjB.resize(nB); sum.resize(nB, 0);
    std::cout << "Weights of nodes in type B: ";
    for (int &v : B) std::cin >> v;

    int m;
    std::cout << "Number of edges: "; std::cin >> m;
    std::cout << "Edges: " << std::endl;
    for (int i = 0; i < m; ++i) {
        int x, y; std::cin >> x >> y;
        --x; --y;
        adjA[x].push_back(y);
        adjB[y].push_back(x);
    }

    auto st_time = std::chrono::steady_clock::now();
    auto ret = solve();
    auto en_time = std::chrono::steady_clock::now();
    std::cout << "Answer:";
    for (int v : ret) std::cout << ' ' << (v + 1);
    std::cout << std::endl;

    std::cout << "Took "
        << std::chrono::duration_cast<std::chrono::milliseconds>(en_time - st_time).count()
        << "ms" << std::endl;
}

I randomly generated several batches of data with nA = nB = 1e6 and m = 2e6, and the program always produced the answer in under 800 ms on my computer (excluding I/O time, compiled with -O2). The time complexity of this solution is O((n + m) log(n + m)), since emplace is called at most n + m times: once per B node initially, plus at most once per edge.
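To reproduce a benchmark of this kind, a hypothetical generator for the input format read by main() above (nA, the A weights, nB, the B weights, m, then m lines "a b", 1-indexed, a in A and b in B) might look like this. The weight range [1, maxW], the seed and the uniform choice of endpoints are assumptions, since the answer does not say how its batches were generated; random endpoints can also produce repeated edges.

```cpp
#include <iostream>
#include <random>
#include <sstream>
#include <string>

// Hypothetical input generator (not from the answer). Weights are uniform in
// [1, maxW]; each edge picks a uniform A endpoint and a uniform B endpoint.
std::string makeInput(int nA, int nB, int m, int maxW, unsigned seed) {
    std::mt19937 rng(seed);
    std::uniform_int_distribution<int> w(1, maxW), pickA(1, nA), pickB(1, nB);
    std::ostringstream os;
    os << nA << '\n';
    for (int i = 0; i < nA; ++i) os << w(rng) << ' ';     // type A weights
    os << '\n' << nB << '\n';
    for (int i = 0; i < nB; ++i) os << w(rng) << ' ';     // type B weights
    os << '\n' << m << '\n';
    for (int i = 0; i < m; ++i)                           // edges "a b", 1-indexed
        os << pickA(rng) << ' ' << pickB(rng) << '\n';
    return os.str();
}
```

For example, a separate program doing std::cout << makeInput(1000000, 1000000, 2000000, 100, 42); can be piped into the solver. With maxW = 100 and about two edges per node the values stay small, but larger weights are where 64-bit sums matter.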

Sorry for my poor English. Feel free to point out my typos and mistakes.

Upvotes: 1

ravenspoint

Reputation: 20596

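This is a further refinement of the algorithm suggested by Dave in a comment. It minimizes the number of times a node value needs to be recalculated: a node is recalculated only when it reaches the top of the heap after one of its neighbors has been deleted. This is valid because, with non-negative weights, a node's value can only decrease, so a stale heap entry never underestimates the true value.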

  1. Run step 1 for every B node, placing the B nodes in a max heap keyed by value.
  2. Check the top node: if any of its neighbors have been deleted, recalculate its value and reinsert it into the heap. Otherwise, add it to the output and delete its neighbors.
  3. Repeat step 2 until all B nodes have been output.
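A self-contained sketch of these three steps, using std::priority_queue and only the B to A adjacency, could look like this. It is not ravenspoint's PathFinder code: lazyOrder, the heap entry layout and the tie rules are hypothetical. It assumes strictly positive weights and stores (value, B weight, -index), so once only zero values remain the unlinked B nodes come out by decreasing weight, standing in for the separate "Bunlinked" pass in the code below. Each reinsertion prunes at least one edge, so there are at most m reinsertions.

```cpp
#include <queue>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical lazy re-evaluation: only the heap top is ever recalculated.
// Assumes strictly positive weights; stale entries overestimate, never underestimate.
std::vector<int> lazyOrder(const std::vector<long long>& wA,
                           const std::vector<long long>& wB,
                           std::vector<std::vector<int>> adjB) {  // copy: pruned in place
    using Entry = std::tuple<long long, long long, int>;       // (value, B weight, -index)
    std::vector<bool> delA(wA.size(), false);
    std::priority_queue<Entry> heap;

    // Step 1: initial value of every B node.
    for (int y = 0; y < static_cast<int>(wB.size()); ++y) {
        long long sum = 0;
        for (int x : adjB[y]) sum += wA[x];
        heap.emplace(sum * wB[y], wB[y], -y);
    }

    std::vector<int> out;
    while (!heap.empty()) {
        const int y = -std::get<2>(heap.top());
        heap.pop();

        // Step 2: drop deleted neighbours and recompute the value.
        std::vector<int> alive;
        long long sum = 0;
        for (int x : adjB[y])
            if (!delA[x]) { alive.push_back(x); sum += wA[x]; }
        const bool stale = alive.size() != adjB[y].size();
        adjB[y] = std::move(alive);

        if (stale) {
            heap.emplace(sum * wB[y], wB[y], -y);  // neighbour lost: reinsert
        } else {
            out.push_back(y);                      // valid maximum: output it
            for (int x : adjB[y]) delA[x] = true;  // and delete its A neighbours
        }
    }
    return out;
}
```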

I have implemented this algorithm in C++ on top of my PathFinder graph class. On a 1-million-node graph with half A nodes and half B nodes, each B node connected to two random A nodes, the code takes about 1 second.

Here is the code:

void cPathFinder::karup()
    {
        raven::set::cRunWatch aWatcher("karup");
        std::cout << "karup on " << nodeCount() << " node graph\n";
        std::vector<int> output;

        // calculate initial values of B nodes
        std::multimap<int, int> mapValueNode;
        for (auto &b : nodes())
        {
            if (b.second.myName[0] != 'b')
                continue;
            int value = 0;
            for (auto a : b.second.myLink)
            {
                value += node(a.first).myCost;
            }
            value *= b.second.myCost;
            mapValueNode.insert(std::make_pair(value, b.first));
        }

        // while not all B nodes output
        while (mapValueNode.size())
        {
            raven::set::cRunWatch aWatcher("select");

            // select node with highest value
            auto remove_it = --mapValueNode.end();
            int remove = remove_it->second;

            if (!remove_it->first)
            {
                /** all remaining nodes have zero value
                 * all the links from B nodes to A nodes have been removed
                 * output remaining nodes in order of decreasing node weight
                 */
                raven::set::cRunWatch aWatcher("Bunlinked");
                std::multimap<int, int> mapNodeValueNode;
                for (auto &nv : mapValueNode)
                {
                   mapNodeValueNode.insert( 
                       std::make_pair( 
                           node(nv.second).myCost,
                           nv.second ));
                }
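                // std::multimap is sorted ascending, so iterate in reverse for decreasing weight
                for( auto it = mapNodeValueNode.rbegin(); it != mapNodeValueNode.rend(); ++it )
                {
                    myPath.push_back( it->second );
                }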
                break;
            }

            bool OK = true;
            int value = 0;
            {
                raven::set::cRunWatch aWatcher("check");

                // check that no nodes providing value have been removed

                // std::cout << "checking neighbors of " << name(remove) << "\n";

                auto &vl = node(remove).myLink;
                for (auto it = vl.begin(); it != vl.end();)
                {
                    if (!myG.count(it->first))
                    {
                        // A neighbour has been removed
                        OK = false;
                        it = vl.erase(it);
                    }
                    else
                    {
                        // A neighbour remains
                        value += node(it->first).myCost;
                        it++;
                    }
                }
            }

            if (OK)
            {
                raven::set::cRunWatch aWatcher("store");
                // we have a node whose values is highest and valid

                // store result
                myPath.push_back(remove); // same container the unlinked B nodes are appended to

                // remove neighbour A nodes
                auto &ls = node(remove).myLink;
                for (auto &l : ls)
                {
                    myG.erase(l.first);
                }
                // remove the B node
                // std::cout << "remove " << name( remove ) << "\n";
                mapValueNode.erase(remove_it);
            }
            else
            {
                // replace old value with new
                raven::set::cRunWatch aWatcher("replace");
                value *= node(remove).myCost;
                mapValueNode.erase(remove_it);
                mapValueNode.insert(std::make_pair(value, remove));
            }
        }
    }

Here are the timing results:

karup on 1000000 node graph
raven::set::cRunWatch code timing profile
Calls      Mean (secs)    Total       Scope
      1    1.16767        1.16767     karup
 581457    1.37921e-06    0.801951    select
 581456    4.71585e-07    0.274206    check
 564546    3.04042e-07    0.171646    replace
      1    0.153269       0.153269    Bunlinked
  16910    8.10422e-06    0.137042    store

Upvotes: 1
