Lydia

Reputation: 95

How to add edge attributes for GraphSAGE link prediction

I'm new to graph neural networks and am trying to do link prediction (binary classification), but I can't work out how to incorporate edge attributes into my SAGEConv layers. The documentation states that SAGEConv doesn't support edge attributes; is there a workaround that lets me include them? Below is the code I'm practicing with, which I found on Medium. I would greatly appreciate any help.

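import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.data import HeteroData
from torch_geometric.nn import SAGEConv, to_hetero

# NOTE: `data` is assumed to be a HeteroData object loaded earlier (not shown).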
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv(hidden_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
# Our final classifier applies the dot-product between source and destination
# node embeddings to derive edge-level predictions:
class Classifier(torch.nn.Module):
    def forward(self, x_user: Tensor, x_movie: Tensor, edge_label_index: Tensor) -> Tensor:
        # Convert node embeddings to edge-level representations:
        edge_feat_user = x_user[edge_label_index[0]]
        edge_feat_movie = x_movie[edge_label_index[1]]
        # Apply dot-product to get a prediction per supervision edge:
        return (edge_feat_user * edge_feat_movie).sum(dim=-1)

class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        # Since the dataset does not come with rich features, we also learn two
        # embedding matrices for users and movies:
        self.movie_lin = torch.nn.Linear(20, hidden_channels)
        self.user_emb = torch.nn.Embedding(data["user"].num_nodes, hidden_channels)
        self.movie_emb = torch.nn.Embedding(data["movie"].num_nodes, hidden_channels)
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        self.classifier = Classifier()
    def forward(self, data: HeteroData) -> Tensor:
        x_dict = {
          "user": self.user_emb(data["user"].node_id),
          "movie": self.movie_lin(data["movie"].x) + self.movie_emb(data["movie"].node_id),
        } 
        # `x_dict` holds feature matrices of all node types
        # `edge_index_dict` holds all edge indices of all edge types
        x_dict = self.gnn(x_dict, data.edge_index_dict)
        pred = self.classifier(
            x_dict["user"],
            x_dict["movie"],
            data["user", "rates", "movie"].edge_label_index,
        )
        return pred
        
model = Model(hidden_channels=64)

Upvotes: 2

Views: 433

Answers (1)

omegabuz

Reputation: 1

SAGEConv

Make sure the edge_index you pass in has shape (2, |E|).

For instance, suppose your graph is given as an adjacency matrix of shape (|V|, |V|):

idx   0  1  2
0   [[0, 1, 1], 
1    [1, 0, 1], 
2    [1, 1, 0]]

It should be converted to an edge list of shape (2, |E|), with one column per nonzero entry:

idx   0  1  2  3  4  5
0   [[0, 0, 1, 1, 2, 2],
1    [1, 2, 0, 2, 0, 1]]
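
The conversion above can be sketched with plain PyTorch for a single graph. This is only an illustration that assumes a dense 0/1 adjacency matrix; the variable names are my own.

```python
import torch

# Dense adjacency matrix of the 3-node example graph, shape (|V|, |V|).
adj = torch.tensor([[0, 1, 1],
                    [1, 0, 1],
                    [1, 1, 0]])

# nonzero() returns one (row, col) pair per edge, shape (|E|, 2);
# transposing gives the (2, |E|) layout expected for edge_index.
edge_index = adj.nonzero().t().contiguous()

print(edge_index.shape)  # torch.Size([2, 6])
print(edge_index)
```

The first row holds the source node of each edge and the second row the destination node, matching the table above.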

If your graph is stored as a dense adjacency matrix instead, you first need to convert it as follows before making a forward pass:

import torch
from torch_geometric.nn import SAGEConv

class GNN(torch.nn.Module):
    def __init__(self, num_nodes, hidden_channels):
        super().__init__()
        self.num_nodes = num_nodes
        self.conv1 = SAGEConv(in_channels=2, out_channels=hidden_channels, aggr='mean')

    def forward(self, adj, coord):
        """
        Args:
            adj (torch.Tensor): Adjacency matrix of shape (B x V x V), where B is the batch size and V is the number of nodes.
            coord (torch.Tensor): Coordinate matrix of shape (B x V x 2)
        """
        # One (batch, row, col) triple per nonzero entry of the adjacency tensor
        b, row, col = adj.eq(1).nonzero(as_tuple=True)
        # Offset both endpoints by the graph's position in the batch, so that
        # indices match the flattened (B*V) x 2 feature matrix used below
        offset = b * self.num_nodes
        edge_index = torch.stack([row + offset, col + offset], dim=0)  # B x V x V -> 2 x |E|
        # send it to SAGEConv
        x = self.conv1(coord.view(-1, 2), edge_index)  # e.g. my node features shape is (B*|V|, F_in = 2)
        return x
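
To sanity-check the batched conversion without a model, you can build edge_index by hand for a tiny batch. The sketch below assumes the node features are flattened to (B*V, F), so node i of graph b ends up at index b*V + i and both rows of edge_index need that offset. `batched_edge_index` is a hypothetical helper written for this example, not a PyTorch Geometric function.

```python
import torch

def batched_edge_index(adj: torch.Tensor) -> torch.Tensor:
    """Convert a (B, V, V) 0/1 adjacency tensor to a (2, |E|) edge_index.

    Hypothetical helper: assumes node i of graph b is stored at row
    b * V + i of the flattened node-feature matrix.
    """
    num_nodes = adj.size(1)
    # One (batch, row, col) triple per nonzero entry.
    b, row, col = adj.nonzero(as_tuple=True)
    # Shift both endpoints into the index range of their own graph.
    offset = b * num_nodes
    return torch.stack([row + offset, col + offset], dim=0)

# Two 2-node graphs, each with a single undirected edge (stored both ways).
adj = torch.tensor([[[0, 1], [1, 0]],
                    [[0, 1], [1, 0]]])
edge_index = batched_edge_index(adj)
print(edge_index)
# tensor([[0, 1, 2, 3],
#         [1, 0, 3, 2]])
```

Nodes 0 and 1 belong to the first graph and nodes 2 and 3 to the second, so no edge crosses between graphs.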

Upvotes: 0
