Reputation: 71
I am trying to use the graph attention network (GAT) module in torch_geometric, but I keep running into AssertionError: Static graphs not supported in 'GATConv' with the following code.
import torch
import torch.nn as nn
from scipy.sparse import coo_matrix


class GraphConv_sum(nn.Module):
    def __init__(self, in_ch, out_ch, num_layers, block, adj):
        super(GraphConv_sum, self).__init__()
        # Convert the adjacency matrix to COO format for PyTorch Geometric
        adj_coo = coo_matrix(adj)
        self.edge_index = torch.tensor([adj_coo.row, adj_coo.col], dtype=torch.long)
        self.g_conv = nn.ModuleList()
        self.act = nn.LeakyReLU()
        for n in range(num_layers):
            if n == 0:
                self.g_conv.append(block(in_ch, 16))   # first layer: in_ch -> 16
            elif n > 0 and n < num_layers - 1:
                self.g_conv.append(block(16, 16))      # hidden layers: 16 -> 16
            else:
                self.g_conv.append(block(16, out_ch))  # last layer: 16 -> out_ch

    def forward(self, x):
        for layer in self.g_conv:
            x = layer(x=x, edge_index=self.edge_index)
            x = self.act(x)
            print(x.shape)
        return x[:, 0, :]  # keep only the first node's features per sample
When I replace block with GATConv and run a standard training loop, this error is raised (other conv layers such as GCNConv or SAGEConv work without any problems). I checked the documentation and made sure the input shape was correct (the same shape works for the other conv layers).
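For context, the failing setup is instantiated roughly like this (the adjacency matrix, layer count, and output size here are placeholders, not my real values):

import numpy as np
from torch_geometric.nn import GATConv

adj = np.ones((6, 6)) - np.eye(6)  # placeholder: fully connected graph over 6 nodes
model = GraphConv_sum(in_ch=200, out_ch=32, num_layers=3, block=GATConv, adj=adj)
out = model(torch.randn(1024, 6, 200))  # raises the AssertionError in the first layer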
In the source code, the forward method contains assert x.dim() == 2, "Static graphs not supported in 'GATConv'", but the batch dimension comes into play in the forward pass, so x.dim() is 3; the input shape with the batch dimension is [1024, 6, 200]. However, if I manually change the assert condition to x.dim() == 3, the same error is still raised, as if the condition were not satisfied. I only have a high-level grasp of GAT, so there might be something I am missing.
I would appreciate any insights and help! Thanks!
Upvotes: 2
Views: 4299
Reputation: 71
It turns out that, because of how the attention weights are calculated, GATConv does not support multiple feature matrices sharing a single edge_index (i.e., static graphs). More info: https://github.com/pyg-team/pytorch_geometric/issues/2844
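One workaround, assuming every sample in the batch shares the same edge_index as in the question, is to merge the batch into one large disconnected graph by offsetting each sample's node indices (essentially what torch_geometric.data.Batch does for mini-batching). GATConv handles that layout fine. A minimal sketch:

import torch
from torch_geometric.nn import GATConv

def flatten_static_batch(x, edge_index):
    # x: [B, N, F] node features; edge_index: [2, E] shared by all samples.
    # Offsetting each sample's node indices by b * N places the B copies of
    # the graph side by side as disconnected components of one big graph.
    B, N, F = x.shape
    offsets = (torch.arange(B, device=x.device) * N).view(1, B, 1)
    flat_edge_index = (edge_index.unsqueeze(1) + offsets).reshape(2, -1)  # [2, B * E]
    return x.reshape(B * N, F), flat_edge_index

edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])  # toy ring over 6 nodes
x = torch.randn(1024, 6, 200)

flat_x, flat_ei = flatten_static_batch(x, edge_index)
out = GATConv(200, 16)(flat_x, flat_ei).view(1024, 6, 16)  # back to [B, N, out_ch]

Since attention coefficients are computed per edge, the disconnected components cannot attend to each other, so this matches running GATConv on each sample separately.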
Upvotes: 2