Penguin

Reputation: 2441

Efficient way to get "neuron-edge-neuron" values in a neural network

I'm working on a visual networks project where I'm trying to plot several node-edge-node values in an interactive graph.

I have several neural networks (this is one example):

import torch
import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        x1 = self.fc1(x)
        x = torch.relu(x1)
        x2 = self.fc2(x)
        x = torch.relu(x2)
        x3 = self.fc3(x)
        return x3, x2, x1

net = Model()

How can I get the node-edge-node (neuron-edge-neuron) values in the network in an efficient way? Some of these networks have a large number of parameters. Note that for the first layer it will be input-edge-neuron rather than neuron-edge-neuron.

I tried saving each node values after the fc layers (ie x1,x2,x3) so I won't have to recompute them, but I'm not sure how to do the edges and match them to their corresponding neurons in an efficient way.

The output I'm looking for is a list of lists of node-edge-node values. Though it can also be a tensor of tensors if it's easier. For example, in the above network from the first layer I will have 2 triples (1x2), from the 2nd layer I will have 6 of them (2x3), and in the last layer I will have 3 triples (3x1). The issue is matching nodes (ie neurons) values (one from layer n-1 and one from layer n) with the corresponding edges in an efficient way.
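For scale, here is the kind of naive pure-Python approach I'm trying to avoid (hypothetical sketch using one layer of the example network): triple-nested loops over samples, input neurons, and output neurons, which becomes slow for layers with many parameters.

```python
import torch
import torch.nn as nn

fc = nn.Linear(2, 3)   # one layer of the example network
x = torch.rand(4, 2)   # batch of 4 inputs
out = fc(x)            # pre-activation outputs, shape (4, 3)

# Naive matching: loop over every (sample, input neuron, output neuron)
triples = []
for b in range(x.shape[0]):
    for i in range(fc.in_features):
        for j in range(fc.out_features):
            # (input node value, connecting edge weight, output node value)
            triples.append((x[b, i].item(),
                            fc.weight[j, i].item(),
                            out[b, j].item()))
# 4 samples x 2 inputs x 3 outputs = 24 triples
```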

Upvotes: 1

Views: 250

Answers (1)

ayandas

Reputation: 2268

Confession: I modified your code a bit to make it more convenient; you can do everything in its original form as well. I also changed the specific numbers of neurons just for playing around (I am sure you can revert them).

I created a summary object (returned by the .forward() function) that contains the entire execution trace of the network, i.e. an (input, weight, output) tuple for every layer.

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3, 5)
        self.fc2 = nn.Linear(5, 7)
        self.fc3 = nn.Linear(7, 2)

    def forward(self, x):
        summary = []
        running_x = x
        for layer in self.children():
            out = layer(running_x)
            # triplet of (input, weight, output) for each layer
            summary.append((running_x, layer.weight, out))
            running_x = out

        return summary

model = Model()
batch_size = 32
X = torch.rand(batch_size, 3)
summary = model(X)

The core logic is just this:

for L in summary: # iterate over the (input, weight, output) tuple of each layer
    ip, weight, out = L # unpack them

    # broadcast all three to a common (batch, in_dim, out_dim, 1) shape ...
    ip = ip[:, :, None, None].repeat(1, 1, out.shape[-1], 1)
    weight = weight.T[None, :, :, None].repeat(batch_size, 1, 1, 1)
    out = out[:, None, :, None].repeat(1, ip.shape[1], 1, 1)
    # ... then concatenate along the last axis: this layer's node-edge-node values
    triplets = torch.cat([ip, weight, out], -1)

So the triplets variable (recomputed for each layer on every iteration of the loop) is all you are looking for. It has size

(batch_size, layer_in_dim, layer_out_dim, 3)
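As a side note (my own variation, not part of the answer above): the .repeat() calls materialize full copies of each tensor. A possibly more memory-friendly sketch builds the same (batch, in_dim, out_dim, 3) result with torch.broadcast_tensors, so only the final stack allocates:

```python
import torch
import torch.nn as nn

fc = nn.Linear(3, 5)           # stands in for one layer of the network
batch_size = 32
ip = torch.rand(batch_size, 3) # this layer's input
out = fc(ip)                   # this layer's output

# Broadcast the three views to a common (batch, in_dim, out_dim) shape
# without copying, then stack along a new last axis.
b_ip, b_w, b_out = torch.broadcast_tensors(
    ip[:, :, None],            # (batch, in_dim, 1)
    fc.weight.T[None, :, :],   # (1, in_dim, out_dim)
    out[:, None, :],           # (batch, 1, out_dim)
)
triplets = torch.stack([b_ip, b_w, b_out], dim=-1)  # (32, 3, 5, 3)
```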

Let's look specifically at the triplets for the first layer.

>> triplets.shape
torch.Size([32, 3, 5, 3])

E.g., given sample index b = 12, input neuron index i = 1, and output neuron index j = 3, you get exactly the node-edge-node tuple

>> triplets[b][i][j]
tensor([0.7080, 0.3442, 0.7344], ...)

Verify: Let's manually verify the correctness.

The 12th sample's 1st input feature is

# It's the first layer we are looking at, so the input comes from the user
>> X[12][1]
tensor(0.7080) 

CHECK.

The connecting weight between the 1st input neuron and the 3rd output neuron of the first layer is

>> model.fc1.weight.T[1][3] # nn.Linear stores weight as (out_dim, in_dim), so we need .T
tensor(0.3442, ...)

CHECK.

The output of 3rd neuron for 12th sample can be retrieved from its activation tensor

>> _, _, out = summary[0] # first layer's output tensor
>> out[12][3]
tensor(0.7344, ...)

ALSO CHECK.
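The spot checks above generalize; if you want, an automated sanity check over every layer (my own addition, self-contained so it rebuilds the model and summary from above) can compare each slot of triplets against the raw tensors with torch.allclose:

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3, 5)
        self.fc2 = nn.Linear(5, 7)
        self.fc3 = nn.Linear(7, 2)

    def forward(self, x):
        summary = []
        running_x = x
        for layer in self.children():
            out = layer(running_x)
            summary.append((running_x, layer.weight, out))
            running_x = out
        return summary

model = Model()
batch_size = 32
summary = model(torch.rand(batch_size, 3))

for ip, weight, out in summary:
    ip_r = ip[:, :, None, None].repeat(1, 1, out.shape[-1], 1)
    w_r = weight.T[None, :, :, None].repeat(batch_size, 1, 1, 1)
    out_r = out[:, None, :, None].repeat(1, ip.shape[1], 1, 1)
    triplets = torch.cat([ip_r, w_r, out_r], -1)

    # slot 0 must equal the layer input, slot 1 the (transposed) weight,
    # slot 2 the layer output -- for every sample and every edge
    assert torch.allclose(triplets[..., 0], ip[:, :, None].expand_as(triplets[..., 0]))
    assert torch.allclose(triplets[..., 1], weight.T[None].expand_as(triplets[..., 1]))
    assert torch.allclose(triplets[..., 2], out[:, None, :].expand_as(triplets[..., 2]))
```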


I hope that's what you wanted. If any more info or changes are needed, feel free to comment. I don't think it can get any more efficient than that.

Upvotes: 3
