Rajat Mishra

Reputation: 1

How do we use PyTorch Geometric's RGCNConv model for Heterogeneous Datasets?

I have been using PyTorch Geometric for link prediction on a HeteroData object that I created from external datasets. I want to use RGCNConv for link prediction, but I have been unable to shape my model so that it accepts a HeteroData object as input. I have been following the steps in this tutorial.

My HeteroData object is as follows after some processing:

HeteroData(
  admission={ x=[1024] },
  medicine={ x=[1024] },
  diagnosis={ x=[1024] },
  procedure={ x=[1024] },
  (admission, prescribed_to, medicine)={
    edge_index=[2, 1024], edge_weight=1,
  },
  (admission, diagnosed_with, diagnosis)={
    edge_index=[2, 1024], edge_weight=1,
  },
  (admission, procedure_done, procedure)={
    edge_index=[2, 1024], edge_weight=1,
  },
  (medicine, interacts_with, medicine)={
    edge_index=[2, 2048], edge_weight=-1,
  },
  (medicine, rev_prescribed_to, admission)={
    edge_index=[2, 1024], edge_weight=1,
  },
  (diagnosis, rev_diagnosed_with, admission)={
    edge_index=[2, 1024], edge_weight=1,
  },
  (procedure, rev_procedure_done, admission)={
    edge_index=[2, 1024], edge_weight=1,
  }
)

And one of the models I have been using is as follows:

class RelationPredictionModel(nn.Module):
  def __init__(self, input_dim, hidden_dim, output_dim, num_relations, num_layers, dropout_rate=0.5):
    super(RelationPredictionModel, self).__init__()

    # Define RGCNConv layers
    self.convs = nn.ModuleList()
    for _ in range(num_layers):
      conv = HeteroConv({ # Using the Heterogeneous Convolution Wrapper
        ('admission', 'prescribed_to', 'medicine'): RGCNConv(in_channels=input_dim, out_channels=hidden_dim, num_relations=num_relations),
        ... # similar code for each relation
        ('procedure', 'rev_procedure_done', 'admission'): RGCNConv(in_channels=input_dim, out_channels=hidden_dim, num_relations=num_relations)
      }, aggr='sum')
      self.convs.append(conv)

    self.relu = nn.ReLU()

    self.dropout = nn.Dropout(p=dropout_rate)

  def forward(self, x_dict, edge_index_dict, edge_types):
    for conv in self.convs:
      x_dict = conv(x_dict, edge_index_dict, edge_types)
      x_dict = {key: x.relu() for key, x in x_dict.items()}

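    # nn.Dropout expects a tensor, so apply it to each node type separately
    x_dict = {key: self.dropout(x) for key, x in x_dict.items()}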

    return x_dict

When training my model as follows:

model = RelationPredictionModel(input_dim, hidden_dim, output_dim, num_relations, num_layers)
model = model.to(device)

for epoch in range(num_epochs):
  model.train()
  total_loss = 0

  for batch_data in train_loader:
    batch_data = batch_data.to(device)
    optimizer.zero_grad()

    x_dict = batch_data.x_dict
    edge_index_dict = batch_data.edge_index_dict
    edge_types = batch_data.edge_types

    # Forward pass
    output = model(x_dict, edge_index_dict, edge_types)
    ...

I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-46-a33be3f75ca2> in <cell line: 19>()
     35 
     36     # Forward pass
---> 37     output = model(x_dict, edge_index_dict, edge_types)
     38 
     39     # Compute loss

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

<ipython-input-25-30ce90f7d49d> in forward(self, x_dict, edge_index_dict, edge_types)
     36     # Graph convolutional layers
     37     for conv in self.convs:
---> 38       x_dict = conv(x_dict, edge_index_dict, edge_types)
     39       x_dict = {key: x.relu() for key, x in x_dict.items()}
     40 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/conv/hetero_conv.py in forward(self, *args_dict, **kwargs_dict)
    125                 if edge_type in value_dict:
    126                     has_edge_level_arg = True
--> 127                     args.append(value_dict[edge_type])
    128                 elif src == dst and src in value_dict:
    129                     args.append(value_dict[src])

TypeError: list indices must be integers or slices, not tuple

I have been unable to resolve this error, and using to_hetero instead led to the error below:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 274, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.5", line 16, in forward
    edge_type__user__to__artist = edge_type_dict.get(('user', 'to', 'artist'), None)
AttributeError: 'list' object has no attribute 'get'

Call using an FX-traced Module, line 16 of the traced Module's generated forward function:
    edge_type_dict = torch_geometric_nn_to_hetero_transformer_get_dict(edge_type);  edge_type = None
    edge_type__user__to__artist = edge_type_dict.get(('user', 'to', 'artist'), None)

I hope someone can suggest a way to pass HeteroData to these models.

PS: Sorry if the post is long. First time posting here. Trying to be thorough.

Upvotes: 0

Views: 207

Answers (1)

user12138762

Reputation: 81

I don't know how to use RGCNConv with a HeteroData object. That said, its forward method accepts an edge_type argument that specifies which relation each edge represents, so we can make it work with an ordinary Data object: just enumerate the edge types and store each edge's relation index. Here's a dummy example:

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset
from random import random
from torch_geometric.nn import RGCNConv

num_features = 8

class mydataset(Dataset):
    def __init__(self, n):
        self.n = n
        self.gen = get_fake_graph()
        
    def __len__(self):
        return self.n
    
    def __getitem__(self, n):
        return next(iter(self.gen))
    
        
def get_fake_graph():
    while True:
        data = Data()
        num_v = torch.randint(8, 33, (1,))[0].item()
        num_e = torch.randint(4, 13, (1,))[0].item()

        data.x = torch.rand(num_v, num_features)
        data.edge_index = torch.stack((
            torch.randint(0, num_v, (num_e,)), 
            torch.randint(0, num_v, (num_e,))
        ))
        data.edge_type = torch.randint(0, 2, (num_e,))
        data.y = torch.randint(0, 2, (num_v,))  # random binary label per node, as a tensor
        yield data
    
ds = mydataset(20)
dl = DataLoader(ds, batch_size=4)
t = next(iter(dl))

conv = RGCNConv(in_channels=num_features, out_channels=16, num_relations=2)
print(conv(t.x, t.edge_index, t.edge_type))
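
To apply the same idea to a heterogeneous graph like the one in the question, the per-type tensors can be merged into a single x, edge_index and edge_type. Below is a minimal sketch in plain PyTorch. Note that flatten_hetero is a hypothetical helper (it is not part of PyTorch Geometric), and it assumes every node type has a 2-D feature matrix with the same number of columns. The 1-D x=[1024] features shown in the question would need to be reshaped or embedded first.

```python
import torch


def flatten_hetero(x_dict, edge_index_dict):
    """Merge per-type node features and edges into homogeneous tensors.

    Hypothetical helper, not a PyTorch Geometric API. Assumes all node
    types share the same feature dimension.
    """
    # Give each node type a contiguous block of global node ids.
    offsets, start = {}, 0
    for node_type, feats in x_dict.items():
        offsets[node_type] = start
        start += feats.size(0)
    x = torch.cat(list(x_dict.values()), dim=0)

    edge_indices, edge_types = [], []
    # Enumerate the relations: the position in the dict becomes the relation id.
    for rel_id, ((src, _, dst), edge_index) in enumerate(edge_index_dict.items()):
        # Shift local node ids into the global id space.
        shift = torch.tensor([[offsets[src]], [offsets[dst]]],
                             device=edge_index.device)
        edge_indices.append(edge_index + shift)
        edge_types.append(torch.full((edge_index.size(1),), rel_id,
                                     dtype=torch.long,
                                     device=edge_index.device))
    return x, torch.cat(edge_indices, dim=1), torch.cat(edge_types)


# Tiny made-up graph: 3 admissions, 2 medicines, 4 features each.
x_dict = {
    "admission": torch.rand(3, 4),
    "medicine": torch.rand(2, 4),
}
edge_index_dict = {
    ("admission", "prescribed_to", "medicine"): torch.tensor([[0, 2], [1, 0]]),
    ("medicine", "interacts_with", "medicine"): torch.tensor([[0], [1]]),
}
x, edge_index, edge_type = flatten_hetero(x_dict, edge_index_dict)
print(x.shape, edge_index, edge_type)
```

The resulting tensors could then be passed as conv(x, edge_index, edge_type) to an RGCNConv created with num_relations=len(edge_index_dict), exactly as in the dummy example above. PyTorch Geometric's HeteroData.to_homogeneous() appears to perform a comparable merge, though I have not tried it on this dataset.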

Upvotes: 0
