Trung Le

Unexpected behavior when passing an attribute as torch.LongTensor in PyTorch Geometric

I have a heterogeneous graph dataset (HeteroData) in which I need to store a scalar, named relation, for each graph. The DataLoader works fine if I pass relation as an int or a torch.tensor, but when I pass it as a torch.LongTensor, the value of relation in each batch becomes unpredictable. This is how I create the dataset:

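import torch
from torch_geometric.data import HeteroData

data = HeteroData()
relation = 0  # for example
# One of the following (each tried separately):
data['relation'] = int(relation)               # works: relation=[8] per batch
data['relation'] = torch.tensor(relation)      # works: relation=[8] per batch
data['relation'] = torch.LongTensor(relation)  # unexpected: relation=[39] per batch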
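A minimal sketch in plain PyTorch (no PyG needed) of what the constructors produce on their own. It relies on the legacy behavior of torch.LongTensor: an int argument is interpreted as a size, not as a value, whereas a list argument is interpreted as data. The value relation = 3 is a hypothetical example.

```python
import torch

relation = 3  # hypothetical example value

a = torch.tensor(relation)        # 0-dim tensor holding the value 3
b = torch.LongTensor(relation)    # legacy constructor: an int is read as a SIZE,
                                  # so this is an uninitialized 1-D tensor of 3 elements
c = torch.LongTensor([relation])  # a list is read as DATA: tensor([3])
empty = torch.LongTensor(0)       # relation = 0 gives an empty tensor of shape (0,)

print(a.shape, b.shape, c.shape, empty.shape)
# torch.Size([]) torch.Size([3]) torch.Size([1]) torch.Size([0])
```

If this is the cause, torch.tensor(relation, dtype=torch.long) stores the value itself as a 0-dim int64 tensor, which is what the int and torch.tensor variants effectively do.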

For example, with batch_size = 8, passing relation as an int or a torch.tensor gives a batch like this:

HeteroDataBatch(
  relation=[8],
  entity={
    x=[12661, 15],
    batch=[12661],
    ptr=[9],
  },
...

Notice that relation is a vector of length 8, which is expected. But when passing relation as a torch.LongTensor (batch_size is still 8), the batch looks like this:

HeteroDataBatch(
  relation=[39],
  entity={
    x=[12661, 15],
    batch=[12661],
    ptr=[9],
  },

relation is now a vector of length 39, and when I print the actual values of batch['relation'], they are all zeros. What causes this unexpected behavior?
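A rough sketch of how the two batch lengths could arise, assuming PyG's collation stacks 0-dim tensors (one entry per graph) but concatenates 1-D tensors along dim 0, so their lengths add up. The eight relation values below are hypothetical, chosen only so that they sum to 39; any eight values with that sum would give the same length. Because torch.LongTensor(n) returns uninitialized memory, the concatenated values are arbitrary and may simply happen to read as zeros.

```python
import torch

# Hypothetical per-graph relation values for one batch of 8 graphs (sum = 39).
relations = [5, 0, 7, 3, 9, 2, 6, 7]

# 0-dim tensors (torch.tensor(r)) are stacked: one entry per graph.
stacked = torch.stack([torch.tensor(r) for r in relations])

# 1-D tensors (torch.LongTensor(r) has r elements) are concatenated along dim 0,
# so the batch length is the sum of the relation values, not the batch size.
concatenated = torch.cat([torch.LongTensor(r) for r in relations])

print(stacked.shape)       # torch.Size([8])
print(concatenated.shape)  # torch.Size([39]); contents are uninitialized
```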
