I have a heterogeneous graph dataset in which I need to store a scalar number, namely `relation`, for each graph. The DataLoader works fine if I pass the value of `relation` as an `int` or a `torch.tensor`, but when I pass `relation` as a `torch.LongTensor`, the value of `relation` in each batch becomes unpredictable.
This is how I create the dataset:
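```python
import torch
from torch_geometric.data import HeteroData

data = HeteroData()
relation = 0  # for example
# Each of the three assignments below is an alternative I tried:
data['relation'] = int(relation)               # works as expected
data['relation'] = torch.tensor(relation)      # works as expected
data['relation'] = torch.LongTensor(relation)  # leads to unexpected batches
```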
For example, with `batch_size = 8`, when passing `relation` as an `int` or a `torch.tensor`, a batch looks like this:
```
HeteroDataBatch(
  relation=[8],
  entity={
    x=[12661, 15],
    batch=[12661],
    ptr=[9],
  },
  ...
```
Notice that `relation` is a vector of length 8, which is expected. But when passing `relation` as a `torch.LongTensor` (`batch_size` is still 8), the batch looks like this:
```
HeteroDataBatch(
  relation=[39],
  entity={
    x=[12661, 15],
    batch=[12661],
    ptr=[9],
  },
```
`relation` is now a vector of length 39, and when I print the actual values of `batch['relation']`, they are all zeros. I wonder what causes this unexpected behavior.
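One way to narrow this down is to check what each construction yields for a single graph, before the DataLoader is involved. The sketch below uses plain PyTorch only (no PyG), and the example values 0 and 3 are illustrative assumptions, not values taken from my dataset:

```python
import torch

# Record the shapes produced by torch.tensor(...) and torch.LongTensor(...)
# for a few example relation values (0 and 3 are illustrative only).
shapes = {}
for relation in (0, 3):
    scalar = torch.tensor(relation)       # built from the int value itself
    legacy = torch.LongTensor(relation)   # legacy constructor called with an int
    shapes[relation] = (tuple(scalar.shape), tuple(legacy.shape), legacy.dtype)
    print(relation, shapes[relation])
```

For each example value, this prints the shape of the `torch.tensor` result, the shape of the `torch.LongTensor` result, and the latter's dtype, so the two forms can be compared side by side.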