Hrushi

Reputation: 509

Get some layers from a PyTorch model that is not defined with nn.Sequential

I have a network defined below.

import torch.nn as nn
import torch.nn.functional as F

class model_dnn_2(nn.Module):
    def __init__(self):
        super(model_dnn_2, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 200)
        self.fc2 = nn.Linear(200, 100)
        self.fc3 = nn.Linear(100, 100)
        self.fc4 = nn.Linear(100, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.relu(x)
        x = self.fc4(x)
        return x

I would like to take the last two layers along with the ReLU functions between them. Using the children method, I get the following:

>>> new_model = nn.Sequential(*list(model.children())[-2:])
>>> new_model
Sequential(
  (0): Linear(in_features=100, out_features=100, bias=True)
  (1): Linear(in_features=100, out_features=10, bias=True)
)

But I would like to have the ReLU function present between the layers, just like in the original model, i.e. the new model should look like:

>>> new_model
Sequential(
  (0): Linear(in_features=100, out_features=100, bias=True)
  (1): ReLU()
  (2): Linear(in_features=100, out_features=10, bias=True)
)

I think the children method only lists the modules registered in the class's __init__, and that is why the problem arises.

How can I obtain such a model?

Upvotes: 2

Views: 1325

Answers (1)

Shai

Reputation: 114936

The way you implemented your model, the ReLU activations are not layers, but rather functions. When listing the sub-layers (aka "children") of your module, you therefore do not see the ReLUs.

You can change your implementation:

class model_dnn_2(nn.Module):
    def __init__(self):
        super(model_dnn_2, self).__init__()
        self.layers = nn.Sequential(
          nn.Flatten(),
          nn.Linear(784, 200),
          nn.ReLU(),  # now you are using a ReLU _layer_
          nn.Linear(200, 100),
          nn.ReLU(),  # this is a different ReLU _layer_
          nn.Linear(100, 100),
          nn.ReLU(),
          nn.Linear(100, 10)
        )

    def forward(self, x):
        y = self.layers(x)
        return y
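
With this implementation the ReLU activations are registered as sub-modules, so they show up when you list the model's children. A minimal sketch of the extraction asked about in the question, assuming the refactored model_dnn_2 above (the slice index is my own choice, not part of the original answer):

model = model_dnn_2()
# the last three children of the inner Sequential are
# Linear(100, 100), ReLU, Linear(100, 10)
new_model = nn.Sequential(*list(model.layers.children())[-3:])
print(new_model)

which should print something like:

Sequential(
  (0): Linear(in_features=100, out_features=100, bias=True)
  (1): ReLU()
  (2): Linear(in_features=100, out_features=10, bias=True)
)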

More on the difference between layers and functions can be found here.
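
As a quick illustration of that difference (this snippet is mine, not from the original answer): nn.ReLU is a thin module wrapper around the same functional call, so both forms compute identical outputs; the module form is simply registered on the model and therefore visible through children().

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(4, 10)
# same computation, different packaging: the module version is part of
# the model's structure, the functional version only exists inside forward()
assert torch.equal(nn.ReLU()(x), F.relu(x))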

Upvotes: 2
