Reputation: 11
I would like to adapt the example DGL GATLayer such that instead of learning node representations, the network learns edge weights. That is, I want to build a network that takes a set of node features as input and outputs a weight for each edge. The labels will be a set of "truth edges", which represent which nodes come from a common source, so that I can learn to cluster unseen data in the same way.
I am using as a starting point the code from the following DGL example:
https://www.dgl.ai/blog/2019/02/17/gat.html
import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average over the heads (dim=0 preserves the feature shape)
            return torch.mean(torch.stack(head_outs), dim=0)

class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h
I had hoped I could adapt this to simply return the edges instead of the nodes, e.g. by replacing the line
return self.g.ndata.pop('h')
with
return self.e.ndata.pop('e')
But it seems it is not that simple. I managed to get something to run, but the loss jumped around all over the place and no learning occurred.
I am new to graph networks, though not to deep learning in general. Is what I am trying to do a reasonable thing? Am I missing something crucial in my understanding of how this works? I have been unable to find any easy-to-understand examples of graph networks where the edges themselves are the learning objective, so I'm a bit muddled at the moment. I appreciate any help that anyone can give!
Upvotes: 0
Views: 1380
Reputation: 1
I'm not completely sure because it depends on your input, but self.g is most likely a DGL graph. In the code they access ndata, which stands for node data; if you want to access the graph's edge data, you would access edata instead. Therefore you should write return self.g.edata..., though I'm not sure which edge attribute you're trying to access, which will change what you pass to pop().
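A minimal sketch of what that could look like, assuming the goal is one score per existing edge trained against 0/1 "truth edge" labels (the class name EdgeScoreLayer and the use of a sigmoid/BCE loss are my own assumptions, not part of the original example):

import torch
import torch.nn as nn
import torch.nn.functional as F

class EdgeScoreLayer(nn.Module):
    """Hypothetical layer: scores each existing edge instead of updating node features."""
    def __init__(self, g, in_dim, hidden_dim):
        super(EdgeScoreLayer, self).__init__()
        self.g = g
        self.fc = nn.Linear(in_dim, hidden_dim, bias=False)
        self.edge_fc = nn.Linear(2 * hidden_dim, 1, bias=False)

    def edge_score(self, edges):
        # edge UDF: one logit per edge from the concatenated src/dst features
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        return {'e': self.edge_fc(z2)}

    def forward(self, h):
        self.g.ndata['z'] = self.fc(h)
        self.g.apply_edges(self.edge_score)
        # note: edata, not ndata -- shape is (num_edges, 1)
        return self.g.edata.pop('e')

# usage sketch: edge_labels is a (num_edges, 1) float tensor of 0/1 "truth edges"
# logits = model(node_features)
# loss = F.binary_cross_entropy_with_logits(logits, edge_labels)

The key point is that apply_edges writes results into g.edata, so an edge-level objective should read (and pop) from there rather than from g.ndata.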
Upvotes: 0