Reputation: 698
I am trying to extend the elements of a `TUDataset` dataset.
In particular, I have a dataset obtained via
```python
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root="PROTEIN", name="PROTEINS", use_node_attr=True)
```
I want to add a new vector-like feature to every entry of the dataset.
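```python
import networkx as nx
import torch_geometric.utils

for i, current_g in enumerate(dataset):
    nxgraph = nx.to_numpy_array(torch_geometric.utils.to_networkx(current_g))  # dense adjacency
    feature = do_something(nxgraph)   # do_something: my own feature function
    dataset[i].new_feature = feature  # this assignment does not persist
```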
However, this code doesn't work: an attribute assigned to an element of `dataset` does not persist, as this IPython session shows:
```
In [80]: dataset[2].test = 1

In [81]: dataset[2].test
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/workspace/grouptheoretical/new-experiments/HGP-SL-myfork/main.py in <cell line: 1>()
----> 1 dataset[2].test

AttributeError: 'Data' object has no attribute 'test'

In [82]: dataset[2].__setattr__('test', 1)

In [83]: dataset[2].test
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/workspace/grouptheoretical/new-experiments/HGP-SL-myfork/main.py in <cell line: 1>()
----> 1 dataset[2].test

AttributeError: 'Data' object has no attribute 'test'
```
Each element of `dataset` is a `Data` object (`from torch_geometric.data import Data`).
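I can create a new `Data` element with all the features I want:
```python
from torch_geometric.data import Data

# new_dataset: a pre-allocated list, e.g. [None] * len(dataset)
tmp = dataset[i].to_dict()            # plain dict of graph i's attributes
tmp['new_feature'] = feature          # add the extra attribute
new_dataset[i] = Data.from_dict(tmp)  # rebuild a Data object from the dict
```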
However, I don't know how to create a `TUDataset` (or its parent class, `InMemoryDataset`) from a list of `Data` elements. Does anyone know how to do this, or have another idea for solving the problem? Thanks.
Upvotes: 0
Views: 751
Reputation: 451
One elegant way to reach your goal is to define your own transform.
```python
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform

class Add_Node_Feature(BaseTransform):
    def __init__(self, parameters=None):
        self.parameters = parameters  # whatever parameters do_something needs

    def __call__(self, data: Data) -> Data:
        # replace the node features with the transformed ones
        data.x = do_something(data.x)
        return data
```
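Then apply the transform to the dataset. PyG runs it every time an element is accessed, so each `dataset[i]` comes back with the modified features. (This example overwrites `data.x`; assigning to a new attribute such as `data.new_feature` inside `__call__` adds a feature instead.)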
```python
import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root="PROTEIN", name="PROTEINS", use_node_attr=True)
dataset.transform = T.Compose([Add_Node_Feature()])
```
Upvotes: 2
Reputation: 698
The solution was very easy. In most cases you just need a list (and `DataLoader` works with lists just fine):
```python
dataset = [change_element(element) for element in dataset]
```
where `change_element` returns a new `Data` element as described in the question.
Upvotes: 0