Rocky the Owl

Reputation: 375

How to retain node ordering when converting graph from networkx to pytorch geometric?

Question: How to retain the node ordering/labels when converting a graph from networkx to pytorch geometric?

Code: (to be run in Google Colab)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

import torch
from torch.nn import Linear
import torch.nn.functional as F
torch.__version__

# install pytorch geometric
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.10.0+cpu.html

from torch_geometric.nn import GCNConv
from torch_geometric.utils.convert import to_networkx, from_networkx

# Make the networkx graph
G = nx.Graph()

# Add some cars 
G.add_nodes_from([
      ('Ford', {'y': 0, 'Name': 'Ford'}),
      ('Lexus', {'y': 1, 'Name': 'Lexus'}),
      ('Peugeot', {'y': 2, 'Name': 'Peugeot'}),
      ('Mitsubishi', {'y': 3, 'Name': 'Mitsubishi'}),
      ('Mazda', {'y': 4, 'Name': 'Mazda'}),
])

# Relabel the nodes
remapping = {x[0]: i for i, x in enumerate(G.nodes(data = True))}

G = nx.relabel_nodes(G, remapping, copy=False)

# Add some edges --> A = [(0, 1, 0, 1, 1), (1, 0, 1, 1, 0), (0, 1, 0, 0, 1), (1, 1, 0, 0, 0), (1, 0, 1, 0, 0)] as the adjacency matrix
G.add_edges_from([
                  (0, 1), (0, 3), (0, 4),
                  (1, 2), (1, 3),
                  (2, 1), (2, 4), 
                  (3, 0), (3, 1),
                  (4, 0), (4, 2)
])

# Convert the graph into PyTorch geometric
pyg_graph = from_networkx(G)

pyg_graph.edge_index

When I print the edge indices in the last line of the code, I get a different answer each time I run it. I want to consistently get the same (correct) answer, in which each node keeps the numbering it had in networkx. One run, for example, gives:

tensor([[0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4, 4],
        [4, 2, 4, 2, 3, 0, 1, 1, 4, 0, 1, 3]])

The edge index tensor has two rows: the first lists the source node of each edge and the second lists the target node.

For the node ids to be retained, we would expect node 0 to appear three times in the first (source) list instead of just twice.
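As a sanity check on that expectation, the sketch below builds the two rows by hand from the six distinct undirected edges, in plain Python and without networkx or PyTorch. The helper name expected_edge_index is made up for illustration, and sorting by (source, target) is an assumption about the desired layout, not something PyTorch Geometric is claimed to guarantee.

```python
from collections import Counter

# The six distinct undirected edges from the question, using integer node ids.
undirected_edges = [(0, 1), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4)]


def expected_edge_index(edges):
    """Return (sources, targets) with each undirected edge listed in both directions."""
    directed = set()
    for u, v in edges:
        directed.add((u, v))
        directed.add((v, u))
    ordered = sorted(directed)  # by source, then target (an illustrative choice)
    sources = [u for u, _ in ordered]
    targets = [v for _, v in ordered]
    return sources, targets


sources, targets = expected_edge_index(undirected_edges)
print(sources)
print(targets)

# Each node appears as a source once per neighbour, i.e. as often as its degree.
source_counts = Counter(sources)
print(source_counts[0])
```

Node 0 has three neighbours (1, 3 and 4), so it shows up three times in the source row here, which is the property the tensor printed above fails to satisfy.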

Is there any way for me to force PyTorch Geometric to copy over the node ids?

Thanks

[EDIT] One possible workaround is the following code, which produces edge index and edge weight tensors for PyTorch Geometric:

# Create a dictionary of the mappings from company --> node id
mapping_dict = {x: i for i, x in enumerate(list(G.nodes()))}

# Get the number of nodes
num_nodes = len(mapping_dict)

# Now create a source, target, and edge list for PyTorch geometric graph
edge_source_list = []
edge_target_list = []
edge_weight_list = []

# iterate through all the edges
for e in G.edges():
  # first element of tuple is appended to source edge list
  edge_source_list.append(mapping_dict[e[0]])

  # last element of tuple is appended to target edge list
  edge_target_list.append(mapping_dict[e[1]])

  # add the edge weight to the edge weight list
  edge_weight_list.append(1) 


# now create full edge lists for pytorch geometric - undirected edges need to be defined in both directions

full_source_list = edge_source_list + edge_target_list      # full source list
full_target_list = edge_target_list + edge_source_list      # full target list
full_weight_list = edge_weight_list + edge_weight_list      # full edge weight list

print(len(edge_source_list), len(edge_target_list), len(full_source_list))

# now convert these to torch tensors
edge_index_tensor = torch.tensor([full_source_list, full_target_list], dtype=torch.long)
edge_weight_tensor = torch.tensor(full_weight_list, dtype=torch.float)

Upvotes: 5

Views: 1490

Answers (1)

SultanOrazbayev

Reputation: 16571

It seems this issue was resolved in the comments (the solution proposed by @Sparky05 is to use copy=True, which is the default for nx.relabel_nodes), but below is the explanation for why the node order is changed.
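A minimal sketch of that fix, assuming only that networkx is installed: the graph is rebuilt with the same car names (edges left out for brevity) and relabelled with copy=True. The variable names H, node_order and labels are illustrative and not from the original post.

```python
import networkx as nx

G = nx.Graph()
G.add_nodes_from([
    ("Ford", {"y": 0}),
    ("Lexus", {"y": 1}),
    ("Peugeot", {"y": 2}),
    ("Mitsubishi", {"y": 3}),
    ("Mazda", {"y": 4}),
])

# Map each name to its position in the original insertion order.
remapping = {name: i for i, name in enumerate(G.nodes())}

# copy=True (the default) builds a new graph by walking G in its own node
# order, and leaves G itself untouched.
H = nx.relabel_nodes(G, remapping, copy=True)

node_order = list(H.nodes())
labels = [H.nodes[n]["y"] for n in node_order]
print(node_order)
print(labels)
```

With the copy, node i should carry label y = i, so a later conversion that numbers nodes by their position in the graph sees them in the intended order.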

When copy=False is passed, nx.relabel_nodes re-adds the nodes to the graph in the order in which they come out of a set built from the keys of the remapping dict. The relevant lines in the code are here:

def _relabel_inplace(G, mapping):
    old_labels = set(mapping.keys())
    new_labels = set(mapping.values())
    if len(old_labels & new_labels) > 0:
        # code for overlapping label sets omitted
        ...
    else:
        # non-overlapping label sets
        nodes = old_labels

    # lines omitted
    for old in nodes:  # iterates in set order, not insertion order
        ...

Because a set is used, the original order of the nodes is lost. To preserve it, the non-overlapping case should iterate over the mapping's keys directly:

    else:
        # non-overlapping label sets
        nodes = mapping.keys()

A related PR is submitted here.
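The difference between the two iteration orders can be seen in plain Python, without networkx. In the sketch below, readd_order is a hypothetical helper that only models the order in which new labels would be added to the graph; it is not part of networkx.

```python
names = ["Ford", "Lexus", "Peugeot", "Mitsubishi", "Mazda"]
remapping = {name: i for i, name in enumerate(names)}

# Dict keys iterate in insertion order (guaranteed since Python 3.7).
key_order = list(remapping.keys())

# A set of strings iterates in hash order, which is unrelated to insertion
# order and can change between interpreter runs due to hash randomization.
set_order = list(set(remapping.keys()))


def readd_order(old_labels, mapping):
    """Model the order in which relabelled nodes are re-added to a graph."""
    return [mapping[old] for old in old_labels]


print(readd_order(key_order, remapping))  # follows the original order
print(readd_order(set_order, remapping))  # some permutation, may vary per run
```

Since string hashes are randomized per process by default, the second line can print a different permutation on each run, which is consistent with the edge indices changing from run to run in the question.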

Upvotes: 3
