Fardin Ahsan

Reputation: 53

How to extract graph node embeddings from a Pytorch-Geometric GAT model?

Dataset Structure: Temporal directed graph; nodes have features; edges don't have features; nodes are labelled. I am using the Elliptic dataset.

Task: Classify nodes (predict node labels).

Data Structure: 2 .csv files of nodes and edges.

I want to train various Graph Neural Networks on the data and extract node embeddings from the networks. I know that is possible because the authors of the Elliptic dataset extracted node embeddings from a GCN.

Below is the code for the GAT I am using.

import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv

class GAT(torch.nn.Module):
  """Graph Attention Network"""
  def __init__(self, dim_in, dim_h, dim_out, heads=24):
    super().__init__()
    self.gat1 = GATv2Conv(dim_in, dim_h, heads=heads)
    self.gat2 = GATv2Conv(dim_h*heads, dim_out, heads=1)
    self.optimizer = torch.optim.Adam(self.parameters(),
                                      lr=0.25,
                                      weight_decay=5e-4)

  def forward(self, x, edge_index):
    h = F.dropout(x, p=0.5, training=self.training)
    # Pass the dropped-out features h (not the raw x) to the first layer
    h = self.gat1(h, edge_index)
    h = F.elu(h)
    h = F.dropout(h, p=0.5, training=self.training)
    h = self.gat2(h, edge_index)
    # Returns (raw class scores, log-probabilities)
    return h, F.log_softmax(h, dim=1)

This function trains the model and returns it:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(model, data, epochs=200):
    """Train a GNN model and return the trained model."""
    # The model already returns log-probabilities; CrossEntropyLoss applies
    # log_softmax again, which leaves log-probabilities unchanged.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = model.optimizer

    model = model.to(device)
    data = data.to(device)  # move the graph to the device once

    model.train()
    for epoch in range(epochs + 1):
        # Training
        optimizer.zero_grad()
        _, out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        # Print metrics every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch {epoch:>3} | Train Loss: {loss.item():.3f}')

    return model

What modifications do I need to make to the code to extract the node embeddings?

Upvotes: 2

Views: 981

Answers (1)

Eli

Reputation: 1

For a large graph, you can add a method like this to the model. It computes the embeddings layer by layer, over mini-batches drawn from a subgraph loader:

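@torch.no_grad()
def representation(self, x_all, subgraph_loader, device):
    """Compute node embeddings layer by layer, one mini-batch at a time."""
    # Assumes the model keeps its layers in self.convs and that each batch
    # from subgraph_loader exposes n_id, edge_index and batch_size.
    for i, conv in enumerate(self.convs):
        xs = []
        for batch in subgraph_loader:
            x = x_all[batch.n_id.to(x_all.device)].to(device)
            x = conv(x, batch.edge_index.to(device))
            if i < len(self.convs) - 1:
                x = F.elu_(x)
            # The first batch_size rows belong to the seed nodes of the batch
            xs.append(x[:batch.batch_size].cpu())
        x_all = torch.cat(xs, dim=0)
    return x_all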

This is adapted from a linked example.

If the graph is not large, you can also use get_embeddings from torch_geometric.utils.

Upvotes: 0
