I'm just starting out with the `torch_geometric` library, and I'm working on a custom dataset. However, I seem to be misunderstanding how the library's data loader interacts with my data. Here is a minimal example:
```python
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch


class CustomData(Data):
    def __init__(self, a, b):
        # num_nodes = len(a)
        super().__init__(
            a=a,
            b=b,
            # num_nodes=num_nodes,
        )


data = [CustomData(
    a=torch.ones(3),
    b=torch.ones(3),
)] * 10

loader = DataLoader(data, batch_size=2)
for batch in loader:
    print(batch)
```
The commented-out lines show what I'd like to accomplish: specifying the number of nodes so that the `batch` attribute of each mini-batch has the correct size. However, if I uncomment them, I get the following error:
```text
Traceback (most recent call last):
  File "c:\Users\name\Software\molgraph\projects\data.py", line 41, in <module>
    for i, batch in enumerate(loader):
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch\utils\data\dataloader.py", line 631, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch\utils\data\dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch\utils\data\_utils\fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch_geometric\loader\dataloader.py", line 27, in __call__
    return Batch.from_data_list(
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch_geometric\data\batch.py", line 97, in from_data_list
    batch, slice_dict, inc_dict = collate(
                                  ^^^^^^^^
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch_geometric\data\collate.py", line 56, in collate
    out = cls(_base_cls=data_list[0].__class__) # type: ignore
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch_geometric\data\batch.py", line 49, in __call__
    return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\name\Software\molgraph\projects\data.py", line 26, in __init__
    "num_nodes": len(a)
                 ^^^^^^
TypeError: object of type 'NoneType' has no len()
```
If I print `a` and `b` when creating the data, they are the correct tensors; the problem only appears when a batch is loaded. At that point, a print statement inside `__init__` shows `None` for every attribute I pass to the data class. Looking at the `batch` attribute of each mini-batch, it is `None` if I don't specify the number of nodes. I tried reading the source code to understand what's happening, but I'm a bit lost.
I did notice that if I specify `edge_index` and `edge_attr`, the `batch` attribute is fine (though I still see `None` being passed to the data class for my inputs).
I had a similar issue with a custom subclass of `Data`. What worked for me was using the base `Data` class directly, passing the extra attributes as keyword arguments and then setting `num_nodes`, like so:
```python
import torch
from torch_geometric.data import Data

data = Data(
    a=torch.ones(3),
    b=torch.ones(3),
)
data.num_nodes = len(data.a)
```