I am trying to solve a graph classification problem in PyTorch Geometric, and I see that some of my graphs contain isolated nodes (nodes with no edges), which can cause problems. My dataset is a list of PyTorch Geometric Data objects:
dataset = [graph1, graph2, graph3, ...]
where graph1 is a PyTorch Geometric Data object containing the graph's structure, node features and label. PyTorch Geometric already has a transform for precisely this task (RemoveIsolatedNodes), but the documentation doesn't say how to apply it, since it's a class whose constructor takes no arguments.
Upvotes: 1
import torch_geometric.transforms as T

transform = T.RemoveIsolatedNodes()  # instantiate once; it takes no arguments
# Build a new list: reassigning `graph` inside a for-loop would not update `dataset`
dataset = [transform(graph) for graph in dataset]
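To make it clearer what removing isolated nodes involves (and why node features have to be filtered, not just the edges), here is a plain-Python sketch that needs no PyG install. It is only an illustration under simplifying assumptions, not PyG's implementation: the function name is hypothetical, edges are a pair of Python lists instead of a [2, num_edges] tensor, the node count is taken from the feature rows, and it ignores edge attributes and PyG's specific handling of self-loops.

```python
def remove_isolated_nodes_sketch(edge_index, x):
    """Hypothetical helper: drop nodes that appear in no edge.

    edge_index: (sources, targets) as two lists, mirroring PyG's [2, num_edges] layout.
    x: list of per-node feature rows; len(x) is taken as the true node count.
    Returns the re-indexed edges, the filtered features and the boolean keep-mask.
    """
    src, dst = edge_index
    num_nodes = len(x)  # use x, not max(edge_index) + 1, so trailing isolated nodes are seen
    # mask[i] is True if node i is an endpoint of at least one edge
    mask = [False] * num_nodes
    for i in src + dst:
        mask[i] = True
    # Map every kept old index to a new, contiguous index
    new_index = {}
    for old in range(num_nodes):
        if mask[old]:
            new_index[old] = len(new_index)
    new_src = [new_index[i] for i in src]
    new_dst = [new_index[j] for j in dst]
    # Keep only the feature rows of surviving nodes so x stays aligned with the edges
    new_x = [row for row, keep in zip(x, mask) if keep]
    return (new_src, new_dst), new_x, mask


# Toy graph: nodes 0 and 2 are connected, nodes 1 and 3 are isolated
(new_src, new_dst), new_x, mask = remove_isolated_nodes_sketch(
    ([0, 2], [2, 0]), [[0.0], [1.0], [2.0], [3.0]]
)
print(new_src, new_dst, new_x)  # [0, 1] [1, 0] [[0.0], [2.0]]
```

The re-indexing step is the important part: after node 1 is dropped, old node 2 becomes node 1, so every edge endpoint and every feature row has to move together.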
Upvotes: 1
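To do that, you can use the remove_isolated_nodes function from the torch_geometric.utils module. It returns a tuple (edge_index, edge_attr, mask), where mask marks the nodes that were kept. The node features should be filtered with that mask as well; otherwise graph.x no longer lines up with the re-indexed edges. Pass num_nodes explicitly so that isolated nodes with the highest indices are also detected. The code might look as follows: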
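from torch_geometric.utils import remove_isolated_nodes
for graph in dataset:  # each Data object is updated in place
    graph.edge_index, _, mask = remove_isolated_nodes(graph.edge_index, num_nodes=graph.num_nodes)
    graph.x = graph.x[mask]  # drop features of removed nodes so they stay aligned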
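Before and after either approach, it can help to check which graphs actually contain isolated nodes. Below is a small plain-Python sketch of that check; the helper name and the toy graphs are hypothetical stand-ins for the question's dataset. Within PyG itself, Data objects also provide a has_isolated_nodes() method for a yes/no answer.

```python
def isolated_nodes(edge_index, num_nodes):
    """Hypothetical helper: return the sorted indices of nodes that appear in no edge.

    num_nodes must be given explicitly (e.g. the number of feature rows):
    inferring it as max(edge_index) + 1 would silently miss isolated nodes
    whose indices are larger than every edge endpoint.
    """
    src, dst = edge_index
    connected = set(src) | set(dst)
    return [i for i in range(num_nodes) if i not in connected]


# Hypothetical toy graphs standing in for the real dataset
graphs = [
    {"edge_index": ([0, 1], [1, 0]), "num_nodes": 2},  # no isolated nodes
    {"edge_index": ([0, 1], [1, 0]), "num_nodes": 4},  # nodes 2 and 3 are isolated
]
for k, g in enumerate(graphs):
    print(k, isolated_nodes(g["edge_index"], g["num_nodes"]))
# 0 []
# 1 [2, 3]
```

The second toy graph shows why num_nodes matters: nodes 2 and 3 never appear in edge_index, so a node count inferred from the edges alone would report no isolated nodes at all.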
Upvotes: 1