Yixuan Sun

AssertionError in torch_geometric.nn.GATConv

I am trying to use the graph attention network (GAT) module in torch_geometric, but I keep running into AssertionError: Static graphs not supported in 'GATConv' with the following code:

import numpy as np
import torch
import torch.nn as nn
from scipy.sparse import coo_matrix


class GraphConv_sum(nn.Module):
    def __init__(self, in_ch, out_ch, num_layers, block, adj):
        super().__init__()
        # Convert the dense adjacency matrix to COO format for PyTorch Geometric
        adj_coo = coo_matrix(adj)
        edge_index = torch.tensor(np.vstack([adj_coo.row, adj_coo.col]), dtype=torch.long)
        # Register as a buffer so edge_index follows the model on .to(device)
        self.register_buffer("edge_index", edge_index)
        self.act = nn.LeakyReLU()

        # Channel sizes: in_ch -> 16 -> ... -> 16 -> out_ch
        dims = [in_ch] + [16] * (num_layers - 1) + [out_ch]
        self.g_conv = nn.ModuleList([block(dims[i], dims[i + 1]) for i in range(num_layers)])

    def forward(self, x):
        # x has shape [batch_size, num_nodes, in_ch], e.g. [1024, 6, 200]
        for layer in self.g_conv:
            x = layer(x=x, edge_index=self.edge_index)
            x = self.act(x)
            print(x.shape)  # debug: shape after each layer
        return x[:, 0, :]
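For reference, the COO conversion in __init__ is meant to produce an edge_index of shape [2, num_edges], with source nodes in the first row and target nodes in the second. The sketch below does the same conversion using only torch, assuming adj is a dense square tensor in which a nonzero adj[i, j] means an edge from node i to node j. dense_to_edge_index is a hypothetical helper name, not a PyG function (PyG does provide torch_geometric.utils.dense_to_sparse, which also returns the nonzero values as edge weights).

```python
import torch


def dense_to_edge_index(adj: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convert a dense [N, N] adjacency matrix to a [2, num_edges] edge_index.

    Row 0 holds source nodes and row 1 holds target nodes, listed in
    row-major order (the same order coo_matrix(adj) gives for a dense array).
    """
    if adj.dim() != 2 or adj.size(0) != adj.size(1):
        raise ValueError("adj must be a square 2-D matrix")
    # nonzero() returns [num_edges, 2] (row, col) pairs; transpose to [2, num_edges]
    return adj.nonzero().t().contiguous().long()


if __name__ == "__main__":
    # 3-node path graph 0 - 1 - 2, stored in both directions (undirected)
    adj = torch.tensor([[0, 1, 0],
                        [1, 0, 1],
                        [0, 1, 0]])
    print(dense_to_edge_index(adj))
    # tensor([[0, 1, 1, 2],
    #         [1, 0, 2, 1]])
```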

When I pass GATConv as block and run a standard training loop, this error is raised (other conv layers such as GCNConv and SAGEConv work without problems). I checked the documentation and made sure the input shape was correct (it is the same as for the other conv layers).

In the source code, the forward method contains assert x.dim() == 2, "Static graphs not supported in 'GATConv'", but the batch dimension comes into play in the forward pass, so x.dim() is 3. The input shape with the batch dimension is [1024, 6, 200], i.e. [batch_size, num_nodes, in_ch]. However, if I manually change the assert condition to x.dim() == 3, the same error is still raised, as if the condition were not satisfied. I only have a high-level grasp of GAT, so I may be missing something. Anyway, I have a few questions about this.
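Since that assert only accepts a 2-D node-feature matrix of shape [num_nodes, in_channels], one straightforward (if slow) way to feed a batched input of shape [batch_size, num_nodes, in_channels] is to call the layer once per sample and stack the results. The sketch below assumes every sample shares the same edge_index. apply_per_sample and SumNeighbors are hypothetical names; SumNeighbors is only a toy stand-in for a real layer such as GATConv so the example runs without PyG.

```python
import torch
import torch.nn as nn


class SumNeighbors(nn.Module):
    """Toy stand-in for a PyG conv layer: each node sums the features of its incoming neighbors."""

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 2, "expects [num_nodes, num_features], like GATConv"
        src, dst = edge_index
        out = torch.zeros_like(x)
        # Add each source node's features to its target node
        out.index_add_(0, dst, x[src])
        return out


def apply_per_sample(conv: nn.Module, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Run conv on each [num_nodes, F] slice of x ([batch, num_nodes, F]) and stack the outputs."""
    return torch.stack([conv(x=sample, edge_index=edge_index) for sample in x], dim=0)


if __name__ == "__main__":
    edge_index = torch.tensor([[0, 1], [1, 2]])  # edges 0 -> 1 and 1 -> 2
    x = torch.randn(4, 3, 5)                     # [batch=4, nodes=3, features=5]
    out = apply_per_sample(SumNeighbors(), x, edge_index)
    print(out.shape)  # torch.Size([4, 3, 5])
```

The Python loop makes this slow for a batch of 1024, but it is easy to check against the batched version.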

I would appreciate any insights and help!! Thanks!

Answers (1)

Yixuan Sun

It turns out that, because of how the attention weights are calculated, GATConv does not support multiple feature matrices (a batch dimension in x) combined with a single shared edge_index. More info: https://github.com/pyg-team/pytorch_geometric/issues/2844
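A common workaround (not taken from the linked issue) is to fold the batch into the node dimension: treat the B samples as one large graph with B * N nodes made of B disconnected copies of the same edges, offsetting the node indices of copy b by b * N. This is the same block-diagonal idea PyG's own mini-batching uses. The sketch below assumes all samples share one edge_index over num_nodes nodes; batch_edge_index and BatchedGraphConv are hypothetical names, and the wrapped conv can be any layer called as conv(x=..., edge_index=...) on 2-D input.

```python
import torch
import torch.nn as nn


def batch_edge_index(edge_index: torch.Tensor, batch_size: int, num_nodes: int) -> torch.Tensor:
    """Replicate a [2, E] edge_index into [2, batch_size * E] for a disjoint union of graphs."""
    num_edges = edge_index.size(1)
    # Copy b is shifted by b * num_nodes; repeat that offset for each of its E edges
    offsets = torch.arange(batch_size, device=edge_index.device).repeat_interleave(num_edges) * num_nodes
    return edge_index.repeat(1, batch_size) + offsets


class BatchedGraphConv(nn.Module):
    """Hypothetical wrapper: applies a 2-D-only conv (e.g. GATConv) to x of shape [B, N, F]."""

    def __init__(self, conv: nn.Module):
        super().__init__()
        self.conv = conv

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        batch_size, num_nodes, _ = x.shape
        big_edge_index = batch_edge_index(edge_index, batch_size, num_nodes)
        # [B, N, F] -> [B * N, F]: node n of sample b becomes node b * N + n
        out = self.conv(x=x.reshape(batch_size * num_nodes, -1), edge_index=big_edge_index)
        # [B * N, F_out] -> [B, N, F_out]
        return out.reshape(batch_size, num_nodes, -1)
```

In the question's model, wrapping each layer as BatchedGraphConv(GATConv(in_ch, 16)) and calling it with x of shape [1024, 6, 200] should hand GATConv a 2-D tensor of shape [6144, 200], so the assert is satisfied and attention never mixes nodes from different samples.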