White

Reputation: 345

How to visualize a torch_geometric graph in Python?

As an example, suppose I have the following adjacency matrix in coordinate (COO) format:

> edge_index.numpy() = array([[    0,     1,     0,   3,   2],
                              [    1,     0,     3,   2,   1]], dtype=int64)

which means that node 0 is linked to node 1 and vice versa, node 0 is linked to node 3, and so on.

How can I draw this graph the way networkx does with nx.draw()?

Upvotes: 8

Views: 12569

Answers (2)

stupidMoewe

Reputation: 41

Another way would be to use the Explanation class from torch_geometric directly:

import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.explain import Explanation

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_mask = torch.tensor([1, 1, 0, 1], dtype=torch.float)

data = Data(x=x, edge_index=edge_index, edge_mask=edge_mask)

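# Pass the fields by keyword: the first positional argument of Explanation is x, not a Data object.
Explanation(x=data.x, edge_index=data.edge_index, edge_mask=data.edge_mask).visualize_graph(backend='networkx')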

This also displays the node IDs in the plot.

Upvotes: 1

Chao Shu

Reputation: 219

import networkx as nx
import torch
import torch_geometric

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = torch_geometric.data.Data(x=x, edge_index=edge_index)
g = torch_geometric.utils.to_networkx(data, to_undirected=True)
nx.draw(g)
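
If only the COO edge index is available, as in the question, the graph can also be built in networkx without going through a Data object. The following is a minimal sketch under that assumption. The helper names coo_to_edge_list and draw_coo_graph are made up for illustration and are not part of torch_geometric or networkx. Nodes that appear in no edge will not show up in the plot.

```python
def coo_to_edge_list(edge_index):
    """Turn a 2 x E COO edge index (row 0: sources, row 1: targets) into (src, dst) pairs."""
    sources, targets = edge_index
    return [(int(s), int(t)) for s, t in zip(sources, targets)]


def draw_coo_graph(edge_index, directed=True):
    """Draw the graph with networkx and return the networkx graph object."""
    # Imported lazily so the conversion helper above works without plotting libraries.
    import networkx as nx
    import matplotlib.pyplot as plt

    graph = nx.DiGraph() if directed else nx.Graph()
    graph.add_edges_from(coo_to_edge_list(edge_index))
    nx.draw(graph, with_labels=True)  # with_labels=True prints the node IDs
    plt.show()
    return graph


# Edge index from the question, written as plain nested lists;
# edge_index.numpy() or edge_index.tolist() could be passed in the same way.
edge_index = [[0, 1, 0, 3, 2],
              [1, 0, 3, 2, 1]]

edges = coo_to_edge_list(edge_index)
```

Calling draw_coo_graph(edge_index) then opens the plot; pass directed=False to merge 0 -> 1 and 1 -> 0 into a single undirected edge.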

Upvotes: 21
