Reputation: 11
When I use link prediction, I encounter an unknown parameter problem.
# Define GraphSAGE model
class SAGE(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout_rate=0.5):
super(SAGE, self).__init__()
self.initial_bn = BatchNorm(in_channels)
self.convs = torch.nn.ModuleList()
self.bns = torch.nn.ModuleList()
self.convs.append(SAGEConv(in_channels, hidden_channels))
self.bns.append(BatchNorm(hidden_channels))
for _ in range(num_layers - 2):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))
self.bns.append(BatchNorm(hidden_channels))
self.convs.append(SAGEConv(hidden_channels, out_channels))
self.dropout = dropout_rate
def reset_parameters(self):
self.initial_bn.reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x, edge_index):
x = self.initial_bn(x)
for i, conv in enumerate(self.convs[:-1]):
x = conv(x, edge_index)
x = self.bns[i](x)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, edge_index)
return x
def decode(self, z, pos_edge_index, neg_edge_index):
edge_label_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
logits = (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
return torch.sigmoid(logits)
My dataset is like this
Data(x=[9730, 1545], edge_index=[2, 242828], train_mask=[9730], val_mask=[9730], test_mask=[9730], pos_edge_label=[7336], pos_edge_label_index=[2, 7336], neg_edge_label=[7336], neg_edge_label_index=[2, 7336])
It is slightly different from the official examples for cora dataset
Data(x=[2708, 1433], edge_index=[2, 9774], y=[2708], train_mask=[2708], test_mask=[2708], num_classes=7, edge_label=[1084], edge_label_index=[2, 1084])
I use pos_edge_label_index and neg_edge_label_index to store.
In the process of instantiating GNNExplainer, I wrote👇
model_config = ModelConfig(
mode='binary_classification',
task_level='edge',
return_type='raw',
)
# Merge positive and negative edge labels and indices, and select a target edge for interpretation
edge_label_index = torch.cat([test_data.pos_edge_label_index, test_data.neg_edge_label_index], dim=1)
edge_label = torch.cat([test_data.pos_edge_label, test_data.neg_edge_label], dim=0)
edge_to_expln_index = edge_label_index[:, 666]
target = edge_label[666].unsqueeze(dim=0).long()
explainer = Explainer(
model=model,
explanation_type='phenomenon',
algorithm=GNNExplainer(epochs=200),
node_mask_type='attributes',
edge_mask_type='object',
model_config=model_config,
)
explanation = explainer(
x=test_data.x,
edge_index=test_data.edge_index,
target=target,
edge_label_index=edge_to_expln_index,
)
available_explanations = explanation.available_explanations
print(f'Generated phenomenon explanations in {available_explanations}')
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[7], line 23
12 target = edge_label[666].unsqueeze(dim=0).long() ## å¾…è§£é‡Šè¾¹çš„çœŸå®žæ ‡ç¾
14 explainer = Explainer(
15 model=model,
16 explanation_type='phenomenon',
(...)
20 model_config=model_config,
21 )
---> 23 explanation = explainer(
24 x=test_data.x,
25 edge_index=test_data.edge_index,
26 target=target,
27 edge_label_index=edge_to_expln_index,
28 )
30 available_explanations = explanation.available_explanations
31 print(f'Generated phenomenon explanations in {available_explanations}')
File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch_geometric\explain\explainer.py:205, in Explainer.__call__(self, x, edge_index, target, index, **kwargs)
202 training = self.model.training
203 self.model.eval()
--> 205 explanation = self.algorithm(
206 self.model,
207 x,
208 edge_index,
209 target=target,
210 index=index,
211 **kwargs,
212 )
214 self.model.train(training)
216 # Add explainer objectives to the `Explanation` object:
File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch_geometric\explain\algorithm\gnn_explainer.py:87, in GNNExplainer.forward(self, model, x, edge_index, target, index, **kwargs)
83 if isinstance(x, dict) or isinstance(edge_index, dict):
84 raise ValueError(f"Heterogeneous graphs not yet supported in "
85 f"'{self.__class__.__name__}'")
---> 87 self._train(model, x, edge_index, target=target, index=index, **kwargs)
89 node_mask = self._post_process_mask(
90 self.node_mask,
91 self.hard_node_mask,
92 apply_sigmoid=True,
93 )
94 edge_mask = self._post_process_mask(
95 self.edge_mask,
96 self.hard_edge_mask,
97 apply_sigmoid=True,
98 )
File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch_geometric\explain\algorithm\gnn_explainer.py:132, in GNNExplainer._train(self, model, x, edge_index, target, index, **kwargs)
129 optimizer.zero_grad()
131 h = x if self.node_mask is None else x * self.node_mask.sigmoid()
--> 132 y_hat, y = model(h, edge_index, **kwargs), target
134 if index is not None:
135 y_hat, y = y_hat[index], y[index]
File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File C:\Anaconda3\envs\pytorch_siki\lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
TypeError: forward() got an unexpected keyword argument 'edge_label_index'
Upvotes: 0
Views: 60