td12345

Turning a list of PyG Data objects into a PyG Dataset?

I have a python list of torch_geometric.data.Data objects (each one representing a graph). There is no easy way for me to access original raw files for this data: I just have the list. I need to turn this list of Data objects into a torch_geometric.data.InMemoryDataset or torch_geometric.data.Dataset object in order to integrate it with a larger code base which I did not write. How do I do this?

To be clear, I know that one can use a list of Data objects to make a torch_geometric.data.DataLoader object. But, I specifically need a Dataset object, not a DataLoader object, as the larger code base does some additional processing steps on Dataset objects before turning them into loaders.

I don't understand why PyG makes this so difficult. Should there not be a very easy way to do this?

I tried using a simple CustomDataset class:

class CustomDataset(InMemoryDataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

and it gave me a KeyIndex error when trying to get the Data object at index 0. I also tried a version of the above code where the super class is Dataset as opposed to InMemoryDataset but I couldn't figure out how to make the collate method work.

Answers (1)

Ruben Band

I made some edits to the CustomDataset code you provided, which might help you:

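from torch_geometric.data import InMemoryDataset


class CustomDataset(InMemoryDataset):
    def __init__(self, listOfDataObjects, transform=None):
        # root=None: no raw or processed files are read or written.
        super().__init__(None, transform)
        # Collate the list into one large Data object plus a dict of slices
        # recording where each graph starts and ends.
        self.data, self.slices = self.collate(listOfDataObjects)

    # No __len__ or __getitem__ override is needed: the inherited len()
    # counts graphs from self.slices (len(self.slices) would count
    # attributes instead), and __getitem__ calls self.get(idx), which uses
    # the slices to cut graph idx back out of the collated data.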

Storing a long list of separate Data objects is inefficient, which is why InMemoryDataset subclasses usually collate the list at the end of __init__. collate returns one large Data object holding every graph, together with a slices dictionary recording where each graph begins and ends. You can't index that collated Data object directly with an integer; doing so results in a key/index error like the one you experienced. However, PyTorch Geometric's InMemoryDataset implements a get method that does the indexing for you: it uses the slices to cut the requested graph back out of the collated object.
