Zhang Fei
Zhang Fei

Reputation: 11

How to solve the unexpected forward() parameter problem of GNNExplainer

When I use link prediction, I encounter an unknown parameter problem.

# Define GraphSAGE model
class SAGE(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout_rate=0.5):
super(SAGE, self).__init__()

        self.initial_bn = BatchNorm(in_channels)
    
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
    
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.bns.append(BatchNorm(hidden_channels))
    
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            self.bns.append(BatchNorm(hidden_channels))
    
        self.convs.append(SAGEConv(hidden_channels, out_channels))
    
        self.dropout = dropout_rate
    
    def reset_parameters(self):
        self.initial_bn.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
    
    def forward(self, x, edge_index):
        x = self.initial_bn(x)
    
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x
    
    def decode(self, z, pos_edge_index, neg_edge_index):
        edge_label_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        logits = (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
        return torch.sigmoid(logits)

My dataset is like this

Data(x=[9730, 1545], edge_index=[2, 242828], train_mask=[9730], val_mask=[9730], test_mask=[9730], pos_edge_label=[7336], pos_edge_label_index=[2, 7336], neg_edge_label=[7336], neg_edge_label_index=[2, 7336])

It is slightly different from the official examples for cora dataset

Data(x=[2708, 1433], edge_index=[2, 9774], y=[2708], train_mask=[2708], test_mask=[2708], num_classes=7, edge_label=[1084], edge_label_index=[2, 1084])

I use pos_edge_label_index and neg_edge_label_index to store.

In the process of instantiating GNNExplainer, I wrote👇

model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='raw',
)

# Merge positive and negative edge labels and indices, and select a target edge for interpretation
edge_label_index = torch.cat([test_data.pos_edge_label_index, test_data.neg_edge_label_index], dim=1)
edge_label = torch.cat([test_data.pos_edge_label, test_data.neg_edge_label], dim=0)
edge_to_expln_index = edge_label_index[:, 666] 
target = edge_label[666].unsqueeze(dim=0).long() 

explainer = Explainer(
    model=model,
    explanation_type='phenomenon',
    algorithm=GNNExplainer(epochs=200),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
)

explanation = explainer(
    x=test_data.x, 
    edge_index=test_data.edge_index,
    target=target,
    edge_label_index=edge_to_expln_index,
)

available_explanations = explanation.available_explanations
print(f'Generated phenomenon explanations in {available_explanations}')

However, I encountered the following error message

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 23
     12 target = edge_label[666].unsqueeze(dim=0).long() ## 待解释边的真实标签
     14 explainer = Explainer(
     15     model=model,
     16     explanation_type='phenomenon',
   (...)
     20     model_config=model_config,
     21 )
---> 23 explanation = explainer(
     24     x=test_data.x, 
     25     edge_index=test_data.edge_index,
     26     target=target,
     27     edge_label_index=edge_to_expln_index,
     28 )
     30 available_explanations = explanation.available_explanations
     31 print(f'Generated phenomenon explanations in {available_explanations}')

File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch_geometric\explain\explainer.py:205, in Explainer.__call__(self, x, edge_index, target, index, **kwargs)
    202 training = self.model.training
    203 self.model.eval()
--> 205 explanation = self.algorithm(
    206     self.model,
    207     x,
    208     edge_index,
    209     target=target,
    210     index=index,
    211     **kwargs,
    212 )
    214 self.model.train(training)
    216 # Add explainer objectives to the `Explanation` object:

File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch_geometric\explain\algorithm\gnn_explainer.py:87, in GNNExplainer.forward(self, model, x, edge_index, target, index, **kwargs)
     83 if isinstance(x, dict) or isinstance(edge_index, dict):
     84     raise ValueError(f"Heterogeneous graphs not yet supported in "
     85                      f"'{self.__class__.__name__}'")
---> 87 self._train(model, x, edge_index, target=target, index=index, **kwargs)
     89 node_mask = self._post_process_mask(
     90     self.node_mask,
     91     self.hard_node_mask,
     92     apply_sigmoid=True,
     93 )
     94 edge_mask = self._post_process_mask(
     95     self.edge_mask,
     96     self.hard_edge_mask,
     97     apply_sigmoid=True,
     98 )

File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch_geometric\explain\algorithm\gnn_explainer.py:132, in GNNExplainer._train(self, model, x, edge_index, target, index, **kwargs)
    129 optimizer.zero_grad()
    131 h = x if self.node_mask is None else x * self.node_mask.sigmoid()
--> 132 y_hat, y = model(h, edge_index, **kwargs), target
    134 if index is not None:
    135     y_hat, y = y_hat[index], y[index]

File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

TypeError: forward() got an unexpected keyword argument 'edge_label_index'

Upvotes: 0

Views: 60

Answers (0)

Related Questions