Reputation: 375
Question: How to retain the node ordering/labels when converting a graph from networkx
to pytorch geometric?
Code: (to be run in Google Colab)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
from torch.nn import Linear
import torch.nn.functional as F
torch.__version__
# install pytorch geometric
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.10.0+cpu.html
from torch_geometric.nn import GCNConv
from torch_geometric.utils.convert import to_networkx, from_networkx
# Make the networkx graph
G = nx.Graph()
# Add some cars
G.add_nodes_from([
('Ford', {'y': 0, 'Name': 'Ford'}),
('Lexus', {'y': 1, 'Name': 'Lexus'}),
('Peugeot', {'y': 2, 'Name': 'Peugeot'}),
('Mitsubishi', {'y': 3, 'Name': 'Mitsubishi'}),
('Mazda', {'y': 4, 'Name': 'Mazda'}),
])
# Relabel the nodes
remapping = {x[0]: i for i, x in enumerate(G.nodes(data = True))}
G = nx.relabel_nodes(G, remapping, copy=False)
# Add some edges --> A = [(0, 1, 0, 1, 1), (1, 0, 1, 1, 0), (0, 1, 0, 0, 1), (1, 1, 0, 0, 0), (1, 0, 1, 0, 0)] as the adjacency matrix
G.add_edges_from([
(0, 1), (0, 3), (0, 4),
(1, 2), (1, 3),
(2, 1), (2, 4),
(3, 0), (3, 1),
(4, 0), (4, 2)
])
# Convert the graph into PyTorch geometric
pyg_graph = from_networkx(G)
pyg_graph.edge_index
When I print the edge indices in the last line of the code, I get a different answer each time I run it. Most importantly, I am looking to consistently get the same (correct) answer, whereby each node's numbering is retained from networkx. For example, one run gives:
tensor([[0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4, 4],
[4, 2, 4, 2, 3, 0, 1, 1, 4, 0, 1, 3]])
In this edge index tensor, the first row lists the source nodes and the second row lists the target nodes.
For the node ids to be retained, we would expect node 0 to appear three times in the first (source) list instead of just twice.
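The mismatch can be checked mechanically. The following is a minimal sketch in plain Python, assuming the adjacency matrix from the code comment above and the source row of the tensor printed above. The variable names are illustrative only. In an undirected graph stored with both edge directions, each node should appear as a source exactly as many times as its degree.

```python
from collections import Counter

# Adjacency matrix taken from the comment in the question's code.
A = [
    (0, 1, 0, 1, 1),
    (1, 0, 1, 1, 0),
    (0, 1, 0, 0, 1),
    (1, 1, 0, 0, 0),
    (1, 0, 1, 0, 0),
]

# Source row of the edge index tensor printed above.
source_row = [0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4, 4]

# Expected number of appearances as a source = node degree (row sum).
expected_degree = {node: sum(row) for node, row in enumerate(A)}
observed = Counter(source_row)

# Nodes whose observed count differs from the expected degree.
mismatched = [n for n in expected_degree if expected_degree[n] != observed[n]]
print(mismatched)  # [0, 4]
```

Nodes 0 and 4 have swapped counts here, which is consistent with the node ids having been permuted during relabeling rather than edges having been lost.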
Is there any way for me to force PyTorch Geometric to copy over the node ids?
Thanks
[EDIT] One possible workaround is the following bit of code, which produces the edge index and edge weight tensors for PyTorch Geometric directly:
# Create a dictionary of the mappings from company --> node id
mapping_dict = {x: i for i, x in enumerate(list(G.nodes()))}
# Get the number of nodes
num_nodes = len(mapping_dict)
# Now create a source, target, and edge list for PyTorch geometric graph
edge_source_list = []
edge_target_list = []
edge_weight_list = []
# iterate through all the edges
for e in G.edges():
    # first element of tuple is appended to source edge list
    edge_source_list.append(mapping_dict[e[0]])
    # last element of tuple is appended to target edge list
    edge_target_list.append(mapping_dict[e[1]])
    # add the edge weight to the edge weight list
    edge_weight_list.append(1)
# now create full edge lists for pytorch geometric - undirected edges need to be defined in both directions
full_source_list = edge_source_list + edge_target_list # full source list
full_target_list = edge_target_list + edge_source_list # full target list
full_weight_list = edge_weight_list + edge_weight_list # full edge weight list
print(len(edge_source_list), len(edge_target_list), len(full_source_list))
# now convert these to torch tensors
edge_index_tensor = torch.tensor([full_source_list, full_target_list], dtype=torch.long)
edge_weight_tensor = torch.tensor(full_weight_list, dtype=torch.float)
Upvotes: 5
Views: 1490
Reputation: 16571
It seems this issue was resolved in the comments (the solution proposed by @Sparky05 is to use copy=True, which is the default for nx.relabel_nodes), but below is the explanation for why the node order is changed.
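As a quick check of that fix, here is a minimal sketch using only networkx. It assumes the same five nodes and edges as in the question, with the car names spelled out in full. The variable names are illustrative.

```python
import networkx as nx

# Build the graph with string labels, as in the question.
G = nx.Graph()
names = ["Ford", "Lexus", "Peugeot", "Mitsubishi", "Mazda"]
G.add_nodes_from((name, {"y": i, "Name": name}) for i, name in enumerate(names))

# Map each label to its position in the current node order.
remapping = {name: i for i, name in enumerate(G.nodes())}

# copy=True (the default) returns a new graph and keeps the node order of G.
H = nx.relabel_nodes(G, remapping, copy=True)

H.add_edges_from([
    (0, 1), (0, 3), (0, 4),
    (1, 2), (1, 3),
    (2, 4),
])

print(list(H.nodes()))  # [0, 1, 2, 3, 4]
print(H.degree[0])      # 3
```

With the node order intact, from_networkx should number the nodes the same way, since as far as I can tell it assigns ids following the order of G.nodes().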
When copy=False is passed, nx.relabel_nodes will re-add the nodes to the graph in the order they appear in the set of keys of the remapping dict. The relevant lines in the code are here:
def _relabel_inplace(G, mapping):
    old_labels = set(mapping.keys())
    new_labels = set(mapping.values())
    if len(old_labels & new_labels) > 0:
        # skip codes for labels sets that overlap
        ...
    else:
        # non-overlapping label sets
        nodes = old_labels
    # skip lines
    for old in nodes:  # this is now in the set order
        ...
By using set, the order of nodes is modified, so to preserve the order, the non-overlapping label sets should be treated as:
    else:
        # non-overlapping label sets
        nodes = mapping.keys()
A related PR is submitted here.
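The difference between the two iteration orders can be seen without networkx. This is a small standalone sketch, with a hand-written mapping standing in for the remapping dict. Dict keys keep their insertion order, while a set of strings has no guaranteed order, and because Python randomizes string hashes per process by default, that order can differ between runs. This would explain getting a different edge index each time.

```python
# Stand-in for the remapping dict from the question (hypothetical values).
mapping = {"Ford": 0, "Lexus": 1, "Peugeot": 2, "Mitsubishi": 3, "Mazda": 4}

# Dict keys iterate in insertion order (guaranteed since Python 3.7).
keys_in_dict_order = list(mapping.keys())

# A set has no guaranteed order; for strings it may change between runs.
keys_in_set_order = list(set(mapping.keys()))

print(keys_in_dict_order)  # ['Ford', 'Lexus', 'Peugeot', 'Mitsubishi', 'Mazda']
print(keys_in_set_order)   # same elements, order not guaranteed
```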
Upvotes: 3