b-riley

Reputation: 1

PyTorch Geometric: None being passed to custom Data subclass while collating

I'm just starting out with the torch_geometric library, and I'm working on a custom dataset. However, I seem to be misunderstanding something about how the library's data loader interacts with the data. A minimal example is below:

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch

class CustomData(Data):
    def __init__(self, a, b):
        # num_nodes = len(a)
        super().__init__(
            a=a,
            b=b,
            # num_nodes=num_nodes,
        )

data = [CustomData(
    a = torch.ones(3),
    b = torch.ones(3)
)] * 10

loader = DataLoader(data, batch_size=2)

for batch in loader:
    print(batch)

The commented-out lines show what I'd like to do: specify the number of nodes so that the batch attribute of a batch of data has the correct size. However, if I uncomment them, I get the following error:

Traceback (most recent call last):
  File "c:\Users\name\Software\molgraph\projects\data.py", line 41, in <module>
    for i, batch in enumerate(loader):
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch\utils\data\dataloader.py", line 631, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch\utils\data\dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch\utils\data\_utils\fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch_geometric\loader\dataloader.py", line 27, in __call__
    return Batch.from_data_list(
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch_geometric\data\batch.py", line 97, in from_data_list
    batch, slice_dict, inc_dict = collate(
                                  ^^^^^^^^
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch_geometric\data\collate.py", line 56, in collate
    out = cls(_base_cls=data_list[0].__class__)  # type: ignore
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\name\Software\molgraph\.conda\Lib\site-packages\torch_geometric\data\batch.py", line 49, in __call__
    return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\name\Software\molgraph\projects\data.py", line 26, in __init__
    "num_nodes": len(a)
                 ^^^^^^
TypeError: object of type 'NoneType' has no len()

If I print the a and b attributes when creating the data, their values are the correct tensors, so construction itself is fine. The problem only appears while a batch is being loaded: at that point a print statement shows None for every argument I pass to the data class. Also, if I don't specify the number of nodes, the batch attribute of each batch is None. I tried reading the source code to understand what's happening, but I'm a bit lost.

I did notice that if I specify edge_index and edge_attr, the batch attribute is fine (though I still see None being passed to the data class for my inputs).

Upvotes: 0

Views: 41

Answers (1)

schmule

Reputation: 1

I had a similar issue with a custom subclass of Data. What worked for me was using the original Data class and setting any extra attributes, such as num_nodes, after construction, like so:

data = Data(
    a = torch.ones(3),
    b = torch.ones(3)
)

data.num_nodes = len(data.a)
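
If keeping the subclass matters, it may help to see why None shows up at all. The traceback in the question shows collate calling cls(_base_cls=data_list[0].__class__), which runs CustomData.__init__ again without the original a and b. The sketch below is a dependency-free stand-in for that behaviour, not PyG's actual code: the metaclass FillMissingWithNone and the classes FragileData and RobustData are hypothetical names invented for illustration, and the assumption is that missing required arguments arrive as None, as the question reports.

```python
import inspect


class FillMissingWithNone(type):
    """Hypothetical stand-in for the collate step: any required __init__
    argument that the caller did not supply is passed as None."""

    def __call__(cls, *args, **kwargs):
        # Skip the first parameter (self) of the class's __init__.
        params = list(inspect.signature(cls.__init__).parameters.items())[1:]
        for i, (name, param) in enumerate(params):
            if param.kind in (inspect.Parameter.VAR_POSITIONAL,
                              inspect.Parameter.VAR_KEYWORD):
                continue  # *args / **kwargs never need filling
            if i < len(args) or name in kwargs:
                continue  # the caller supplied this argument
            if param.default is inspect.Parameter.empty:
                kwargs[name] = None  # required but missing: fill with None
        return super().__call__(*args, **kwargs)


class FragileData(metaclass=FillMissingWithNone):
    """Mirrors the question: assumes a is always a real sequence."""

    def __init__(self, a, b):
        self.a, self.b = a, b
        self.num_nodes = len(a)  # TypeError when a is None


class RobustData(metaclass=FillMissingWithNone):
    """Tolerates the argument-free construction by guarding against None."""

    def __init__(self, a=None, b=None):
        self.a, self.b = a, b
        self.num_nodes = len(a) if a is not None else None


# Normal construction works for both classes.
print(FragileData([1, 1, 1], [1, 1, 1]).num_nodes)
print(RobustData([1, 1, 1], [1, 1, 1]).num_nodes)

# Argument-free construction, as in the collate step, only works for RobustData.
print(RobustData().num_nodes)
try:
    FragileData()
except TypeError as err:
    print("FragileData() failed:", err)
```

Applied to the question's class, the same idea would be to give a and b a default of None in CustomData.__init__ and to set num_nodes only when a is not None. I have not checked this against every torch_geometric version, so treat it as a starting point.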

Upvotes: 0
