Evan Zamir

Reputation: 8451

How to get output from intermediate encoder layers in PyTorch Transformer?

I have trained a fairly simple Transformer model that stacks 6 TransformerEncoderLayer modules:

import pytorch_lightning as pl
import torch
import torch.nn.functional as F

class LitModel(pl.LightningModule):
    def __init__(self,
                 num_tokens: int,
                 dim_model: int = 96,
                 dim_h: int = 128,
                 n_head: int = 1,
                 dropout: float = 0.1,
                 activation: str = 'relu',
                 num_layers: int = 2,
                 lr: float=1e-3):
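        """
        Transformer encoder trained to predict masked tokens.

        :param num_tokens: vocabulary size (embedding rows and output classes)
        :param dim_model: embedding / model dimension (d_model)
        :param dim_h: hidden size of each layer's feed-forward block (dim_feedforward)
        :param n_head: number of attention heads per layer
        :param dropout: dropout probability inside each encoder layer
        :param activation: feed-forward activation, e.g. 'relu'
        :param num_layers: number of stacked TransformerEncoderLayer modules
        :param lr: learning rate (stored as self.lr)
        """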
        super().__init__()
        self.lr = lr
        self.embed = torch.nn.Embedding(num_embeddings=num_tokens,
                                        embedding_dim=dim_model)
        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=dim_model,
                                                         nhead=n_head,
                                                         dim_feedforward=dim_h,
                                                         dropout=dropout,
                                                         activation=activation,
                                                         batch_first=True)
        self.encoder = torch.nn.TransformerEncoder(encoder_layer=encoder_layer,
                                                   num_layers=num_layers)
        self.linear = torch.nn.Linear(in_features=dim_model, out_features=num_tokens)

    def forward(self, indices, mask):
        x = self.embed(indices)
        x = self.encoder(x, src_key_padding_mask=mask)
        return x

    def training_step(self, batch, batch_idx):
        x = batch['src']
        y = batch['label']
        mask = batch['mask']

        x = self.embed(x)
        x = self.encoder(x, src_key_padding_mask=mask)
        x = self.linear(x)

        loss = F.cross_entropy(input=x.transpose(1, 2),
                               target=y,
                               ignore_index=0)
        self.log('train_loss', loss)
        return loss

After training the model to predict [MASK] tokens (as in BERT), I would like to extract the outputs of intermediate layers, specifically the second-to-last TransformerEncoderLayer, which may give a better vector encoding than the final layer (according to the original BERT paper). I'm not sure how to go about doing this.

Upvotes: 1

Views: 3887

Answers (1)

Berriel

Reputation: 13601

Just in case it is not clear from the comments, you can do that by registering a forward hook:

activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

# instantiate the model
model = LitModel(...)

# register the forward hook
model.encoder.layers[-2].register_forward_hook(get_activation('encoder_penultimate_layer'))

# pass some data through the model (forward takes token indices and a padding mask)
output = model(indices, mask)

# this is what you're looking for
activation['encoder_penultimate_layer']
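
For a version that runs end to end, here is a minimal sketch. It assumes a hypothetical stand-in class, TinyEncoder, that mirrors the embed/encoder layout of LitModel as a plain torch.nn.Module (so it runs without pytorch_lightning), with random token ids and an illustrative padding mask. It shows the hook approach, and as an alternative it runs the encoder layers one at a time to keep every intermediate output. The two should agree at the non-padding positions up to floating-point tolerance.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the trained model: same layer layout as LitModel,
# but a plain nn.Module so the sketch does not need pytorch_lightning.
class TinyEncoder(torch.nn.Module):
    def __init__(self, num_tokens=50, dim_model=96, dim_h=128,
                 n_head=1, num_layers=6):
        super().__init__()
        self.embed = torch.nn.Embedding(num_tokens, dim_model)
        layer = torch.nn.TransformerEncoderLayer(
            d_model=dim_model, nhead=n_head, dim_feedforward=dim_h,
            dropout=0.1, activation='relu', batch_first=True)
        self.encoder = torch.nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, indices, mask):
        return self.encoder(self.embed(indices), src_key_padding_mask=mask)

model = TinyEncoder().eval()  # eval() turns dropout off

# Illustrative inputs (assumed shapes): 4 sequences of 10 token ids each.
indices = torch.randint(1, 50, (4, 10))
mask = torch.zeros(4, 10, dtype=torch.bool)  # True marks padding positions
mask[:, -2:] = True                          # treat the last two tokens as padding

# Approach 1: forward hook on the second-to-last encoder layer.
activation = {}

def get_activation(name):
    def hook(module, inputs, output):
        activation[name] = output.detach()
    return hook

handle = model.encoder.layers[-2].register_forward_hook(
    get_activation('encoder_penultimate_layer'))
with torch.no_grad():
    final = model(indices, mask)
handle.remove()  # unregister the hook once the activation is captured

penultimate = activation['encoder_penultimate_layer']
print(penultimate.shape)  # torch.Size([4, 10, 96])

# Approach 2: call the layers one by one and keep each output.
with torch.no_grad():
    x = model.embed(indices)
    hidden_states = []
    for layer in model.encoder.layers:
        x = layer(x, src_key_padding_mask=mask)
        hidden_states.append(x)

# hidden_states[-2] is the second-to-last layer's output.
```

The hook is the less invasive option if you only need one layer and want to keep calling the model as usual. The explicit loop is handy when you want all layers at once. Note that the loop skips the encoder's optional final norm, which is None here because TransformerEncoder was built without one.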

Upvotes: 4
