Upendra01

Reputation: 414

Take the output from a specific layer in PyTorch

I have implemented an autoencoder in PyTorch and wish to extract the representations (outputs) from a specified encoding layer. This is similar to making predictions with a sub-model, as one could do in Keras.

However, implementing something similar in PyTorch looks a bit challenging. I tried forward hooks, as explained in "How to get the output from a specific layer from a PyTorch model?" and https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html, but to no avail.

Could you help me get the outputs from a specific layer?

I have attached my code below:

import torch
import torch.nn.functional as F

class Autoencoder(torch.nn.Module):

    # Now defining the encoding and decoding layers.

    def __init__(self):
        super().__init__()   
        self.enc1 = torch.nn.Linear(in_features = 784, out_features = 256)
        self.enc2 = torch.nn.Linear(in_features = 256, out_features = 128)
        self.enc3 = torch.nn.Linear(in_features = 128, out_features = 64)
        self.enc4 = torch.nn.Linear(in_features = 64, out_features = 32)
        self.enc5 = torch.nn.Linear(in_features = 32, out_features = 16)
        self.dec1 = torch.nn.Linear(in_features = 16, out_features = 32)
        self.dec2 = torch.nn.Linear(in_features = 32, out_features = 64)
        self.dec3 = torch.nn.Linear(in_features = 64, out_features = 128)
        self.dec4 = torch.nn.Linear(in_features = 128, out_features = 256)
        self.dec5 = torch.nn.Linear(in_features = 256, out_features = 784)

    # Now defining the forward propagation step

    def forward(self, x):
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))
        x = F.relu(self.enc5(x))
        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        x = F.relu(self.dec5(x))
    
        return x

autoencoder_network = Autoencoder()

I need to take the outputs from the encoder layers marked enc1, enc2, ..., enc5.
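For context, my understanding of the hook-based approach I was attempting looks roughly like this (a trimmed sketch of the model with fewer layers; note that hooks registered on the Linear modules see the output *before* the functional ReLU):

```python
import torch
import torch.nn.functional as F

# Trimmed version of the autoencoder (fewer layers, same idea)
class Autoencoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = torch.nn.Linear(in_features=784, out_features=256)
        self.enc2 = torch.nn.Linear(in_features=256, out_features=16)
        self.dec1 = torch.nn.Linear(in_features=16, out_features=784)

    def forward(self, x):
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        return F.relu(self.dec1(x))

activations = {}

def save_activation(name):
    # Returns a hook that stores the module's raw output under `name`
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

model = Autoencoder()
# Hooks on the Linear modules fire on every forward pass; they capture
# the pre-activation output, since ReLU is applied functionally
model.enc1.register_forward_hook(save_activation('enc1'))
model.enc2.register_forward_hook(save_activation('enc2'))

out = model(torch.randn(4, 784))
print(activations['enc1'].shape)  # torch.Size([4, 256])
print(activations['enc2'].shape)  # torch.Size([4, 16])
```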

Upvotes: 0

Views: 1821

Answers (2)

abe

Reputation: 987

You can define a global dictionary, e.g. activations = {}, and then assign to it inside the forward function, e.g. activations['enc1'] = x.clone().detach(), and so on for each layer you want to capture.
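Applied to the model in the question, that could look roughly like this (a trimmed two-layer sketch; the layer sizes are illustrative):

```python
import torch
import torch.nn.functional as F

activations = {}  # global store for intermediate outputs

class Autoencoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = torch.nn.Linear(784, 16)
        self.dec1 = torch.nn.Linear(16, 784)

    def forward(self, x):
        x = F.relu(self.enc1(x))
        activations['enc1'] = x.clone().detach()  # stash the encoding
        return F.relu(self.dec1(x))

model = Autoencoder()
out = model(torch.randn(4, 784))
print(activations['enc1'].shape)  # torch.Size([4, 16])
```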

Upvotes: 0

Shai

Reputation: 114786

The simplest way is to explicitly return the activations you need:

    def forward(self, x):
        e1 = F.relu(self.enc1(x))
        e2 = F.relu(self.enc2(e1))
        e3 = F.relu(self.enc3(e2))
        e4 = F.relu(self.enc4(e3))
        e5 = F.relu(self.enc5(e4))
        x = F.relu(self.dec1(e5))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        x = F.relu(self.dec5(x))
    
        return x, e1, e2, e3, e4, e5
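Calling the model then unpacks the encodings alongside the reconstruction. A trimmed, runnable sketch of the same pattern (two layers for brevity):

```python
import torch
import torch.nn.functional as F

class Autoencoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = torch.nn.Linear(784, 16)
        self.dec1 = torch.nn.Linear(16, 784)

    def forward(self, x):
        e1 = F.relu(self.enc1(x))   # encoding to expose
        x = F.relu(self.dec1(e1))   # reconstruction
        return x, e1

autoencoder_network = Autoencoder()
out, e1 = autoencoder_network(torch.randn(4, 784))
# Use only `out` for the reconstruction loss; `e1` is the representation
print(out.shape, e1.shape)  # torch.Size([4, 784]) torch.Size([4, 16])
```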

Upvotes: 1
