Reputation: 2441
Say I have the following model:
import torch
import torch.nn as nn
import torch.optim as optim
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 5)
        self.fc2 = nn.Linear(5, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        x = self.fc1(x)
        print(x)
        x = torch.relu(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Model()
opt = optim.Adam(net.parameters())
features = torch.rand((3,1)) #3 inputs, each of 1D
I can print the values of my neurons (only the first layer here) with print(x):
net(features)
>>>tensor([[ 0.6703, 0.4484, -0.8529, 1.3119, 0.6741],
[ 0.9112, 0.6496, -1.2960, 1.8264, 0.4547],
[ 0.7483, 0.5135, -0.9963, 1.4785, 0.6031]],
grad_fn=<AddmmBackward>)
tensor([[0.0144],
[0.0575],
[0.0284]], grad_fn=<AddmmBackward>)
How can I add a "feature" to each neuron that is a string with a name? e.g.
print(x)
>>> tensor([[[0.6703, 'neuron_1'], [0.4484, 'neuron_2'], [-0.8529, 'neuron_3'], [1.3119, 'neuron_4'], [0.6741, 'neuron_5']], ... etc.
I'm not sure if I'll need to change the neuron class. I believe in the forward method I would then need to take only the first element of each neuron's tensor: neuron_tensor = [neuron_value, neuron_name].
Update 1: from @Aditya Singh Rathore's comment it sounds like it might not be possible to have a string and a value in the same tensor. Is it possible then to have a value instead of a string to represent the neurons?
From before: neuron_tensor = [neuron_value, neuron_name], where neuron_name is a string.
Is this possible instead: neuron_tensor = [neuron_value, neuron_name], where neuron_name is just a value (e.g. 1 for neuron 1, 2 for neuron 2)?
Upvotes: 0
Views: 148
Reputation: 4181
What you want is possible, although I'm not sure what exactly you want to do with it. Basically you seem to want to include the neuron "index" in the intermediate tensor for some purpose, and then discard it again when passing to the next layer.
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 5)
        self.fc2 = nn.Linear(5, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        x = self.fc1(x)
        # Build a (batch, 5) tensor of neuron indices 1..5, one row per input
        idx_tensor = (torch.arange(x.shape[1]) + 1).unsqueeze(0).repeat_interleave(repeats=x.shape[0], dim=0)
        # Stack values and indices along a new last dimension -> shape (batch, 5, 2)
        x = torch.cat([x.unsqueeze(2), idx_tensor.unsqueeze(2)], dim=2)
        print(x)
        # Drop the index channel again before passing to the next layer
        x = torch.relu(x[:, :, 0])
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
Running this gives:
net = Model()
opt = optim.Adam(net.parameters())
features = torch.rand((3,1))
net(features)
>>>tensor([[[ 0.0817, 1.0000],
[ 0.8084, 2.0000],
[ 1.6118, 3.0000],
[ 0.8658, 4.0000],
[-0.1583, 5.0000]],
[[ 0.2881, 1.0000],
[ 0.6946, 2.0000],
[ 1.3760, 3.0000],
[ 0.6098, 4.0000],
[-0.1240, 5.0000]],
[[ 0.1919, 1.0000],
[ 0.7476, 2.0000],
[ 1.4859, 3.0000],
[ 0.7291, 4.0000],
[-0.1400, 5.0000]]], grad_fn=<CatBackward>)
tensor([[-0.2841],
[-0.2191],
[-0.2495]], grad_fn=<AddmmBackward>)
Note that a tensor cannot hold integers and floats at the same time, so the indices 1, 2, 3 will be stored as floats 1.0000, 2.0000, etc.
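If you really do want string names, one option (just a sketch, not something the tensor itself can store) is to keep the names in a plain Python list and pair them with the activations only when printing:
import torch

# Sketch: string names live outside the tensor, since a tensor only holds numbers.
x = torch.tensor([[0.6703, 0.4484, -0.8529, 1.3119, 0.6741]])
neuron_names = [f"neuron_{i + 1}" for i in range(x.shape[1])]

for row in x:
    # Pair each activation with its name purely for display
    print([(name, round(val.item(), 4)) for name, val in zip(neuron_names, row)])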
If your purpose is something like fancy printing, I suggest looking into torch's hook functions. For example, you can do something like:
import pandas as pd

def fc_hook_fn(module, input, output):
    print("\n" + "#" * 60)
    print(f"In layer {module}")
    print("#" * 60 + "\n")
    # One column per neuron, one row per input in the batch
    cols = [f"Neuron-{i + 1}" for i in range(output.shape[1])]
    idx = [f"Input-{i + 1}" for i in range(output.shape[0])]
    neuron_activations = pd.DataFrame(output.detach().numpy(), columns=cols, index=idx)
    print(neuron_activations)
net.fc1.register_forward_hook(fc_hook_fn)
Now each time something passes through fc1, the function above will be triggered. You don't need to put your print(x) in the forward method.
############################################################
In layer Linear(in_features=1, out_features=5, bias=True)
############################################################
         Neuron-1  Neuron-2  Neuron-3  Neuron-4  Neuron-5
Input-1 -0.948735 -0.901034 -0.290353 -0.082616 -0.405337
Input-2 -0.725904 -0.801648 -0.302922 -0.045514 -0.580485
Input-3 -0.829738 -0.847960 -0.297065 -0.062802 -0.498870
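If you want the same printout for every linear layer rather than just fc1, a possible extension (a sketch, not part of the code above) is to register the hook on each nn.Linear module; register_forward_hook returns a handle you can later use to detach the hook:
# Sketch: attach the same hook to every nn.Linear layer in the model
handles = [
    module.register_forward_hook(fc_hook_fn)
    for module in net.modules()
    if isinstance(module, nn.Linear)
]

net(features)  # the hook now fires once per linear layer

for handle in handles:
    handle.remove()  # detach the hooks when you are done inspecting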
Upvotes: 1