Qubix

Reputation: 4353

How can one use the RemoveIsolatedNodes transform in PyTorch Geometric?

I am working on a graph classification problem in PyTorch Geometric, and some of my graphs contain isolated nodes, which can cause problems. My dataset is a list of PyTorch Geometric Data objects:

dataset = [graph1, graph2, graph3, ...]

where each element (graph1, graph2, and so on) is a PyTorch Geometric Data object containing the graph's structure, node features and label. PyTorch Geometric already has a transform for precisely this task, RemoveIsolatedNodes, but I can't find anything that shows how to apply it, since the class constructor takes no arguments.

Upvotes: 1

Views: 1050

Answers (2)

mohammadreza

Reputation: 26

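import torch_geometric.transforms as T

transform = T.RemoveIsolatedNodes()
# A transform is a callable: pass it a Data object and keep the graph it returns.
# Rebinding the loop variable (graph = transform(graph)) would not update the list.
dataset = [transform(graph) for graph in dataset]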

Upvotes: 1

LucTuc

Reputation: 33

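To do that, you can use the remove_isolated_nodes function from torch_geometric.utils. It returns a tuple of the relabeled edge_index, the filtered edge_attr and a boolean node mask, so the node features have to be filtered with that mask as well; updating edge_index alone would leave the features misaligned. The code might look as follows:

from torch_geometric.utils import remove_isolated_nodes
for graph in dataset:
    graph.edge_index, graph.edge_attr, mask = remove_isolated_nodes(graph.edge_index, graph.edge_attr, num_nodes=graph.num_nodes)
    graph.x = graph.x[mask]  # drop the feature rows of the removed nodes (repeat for any other node-level tensors)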
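
The reason the node mask matters is that removing nodes renumbers the remaining ones. The following plain-Python sketch illustrates that idea only; drop_isolated is a hypothetical helper written for this example, not the library's implementation, and it does not treat self-loops specially.

```python
def drop_isolated(edge_index, features):
    """Remove nodes that appear in no edge.

    edge_index: [sources, targets], two equal-length lists of node ids.
    features:   one entry per node, indexed by node id.
    """
    num_nodes = len(features)
    used = set(edge_index[0]) | set(edge_index[1])
    mask = [node in used for node in range(num_nodes)]

    # Map every kept node to a new consecutive id.
    new_id = {}
    for old in range(num_nodes):
        if mask[old]:
            new_id[old] = len(new_id)

    new_edges = [[new_id[s] for s in edge_index[0]],
                 [new_id[t] for t in edge_index[1]]]
    # The features must be filtered with the same mask, or they no longer
    # line up with the renumbered edges.
    new_features = [f for f, keep in zip(features, mask) if keep]
    return new_edges, new_features, mask


edges = [[0, 1, 1, 3], [1, 0, 3, 1]]  # node 2 is isolated
feats = ["a", "b", "c", "d"]
new_edges, new_feats, mask = drop_isolated(edges, feats)
print(new_edges)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(new_feats)  # ['a', 'b', 'd']
print(mask)       # [True, True, False, True]
```

Old node 3 becomes node 2 after the removal, which is why the features have to be filtered together with the edges.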

Upvotes: 1
