Reputation: 8451
I have trained a fairly simple Transformer model with 6 TransformerEncoder layers:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F


class LitModel(pl.LightningModule):
    def __init__(self,
                 num_tokens: int,
                 dim_model: int = 96,
                 dim_h: int = 128,
                 n_head: int = 1,
                 dropout: float = 0.1,
                 activation: str = 'relu',
                 num_layers: int = 2,
                 lr: float = 1e-3):
        """
        :param num_tokens: vocabulary size
        :param dim_model: embedding / model dimension
        :param dim_h: feed-forward hidden dimension
        :param n_head: number of attention heads
        :param dropout: dropout probability
        :param activation: activation of the feed-forward network
        :param num_layers: number of TransformerEncoder layers
        :param lr: learning rate
        """
        super().__init__()
        self.lr = lr
        self.embed = torch.nn.Embedding(num_embeddings=num_tokens,
                                        embedding_dim=dim_model)
        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=dim_model,
                                                         nhead=n_head,
                                                         dim_feedforward=dim_h,
                                                         dropout=dropout,
                                                         activation=activation,
                                                         batch_first=True)
        self.encoder = torch.nn.TransformerEncoder(encoder_layer=encoder_layer,
                                                   num_layers=num_layers)
        self.linear = torch.nn.Linear(in_features=dim_model, out_features=num_tokens)

    def forward(self, indices, mask):
        x = self.embed(indices)
        x = self.encoder(x, src_key_padding_mask=mask)
        return x

    def training_step(self, batch, batch_idx):
        x = batch['src']
        y = batch['label']
        mask = batch['mask']
        x = self.embed(x)
        x = self.encoder(x, src_key_padding_mask=mask)
        x = self.linear(x)
        loss = F.cross_entropy(input=x.transpose(1, 2),
                               target=y,
                               ignore_index=0)
        self.log('train_loss', loss)
        return loss
After training the model to predict [MASK] tokens (exactly like BERT), I would like to extract the outputs of the lower layers, specifically the second-to-last TransformerEncoderLayer, which may give a better vector encoding than the final layer (according to the original BERT paper). I'm not sure how to go about doing this.
Upvotes: 1
Views: 3887
Reputation: 13601
Just in case it is not clear from the comments, you can do that by registering a forward hook:
activation = {}

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

# instantiate the model
model = LitModel(...)

# register the forward hook on the second-to-last encoder layer
model.encoder.layers[-2].register_forward_hook(get_activation('encoder_penultimate_layer'))

# pass some data through the model (forward expects the token indices and the padding mask)
output = model(x, mask)

# this is what you're looking for
activation['encoder_penultimate_layer']
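If you would rather not use hooks, you can also run the encoder layers yourself: TransformerEncoder keeps them in a ModuleList at model.encoder.layers, so applying every layer except the last one gives you the penultimate layer's output directly. A minimal sketch, assuming indices and mask are batched tensors shaped like the ones in your training_step (the names are just placeholders):

# a minimal sketch, assuming `indices` and `mask` look like the batch
# tensors used in training_step
model.eval()
with torch.no_grad():
    h = model.embed(indices)
    # apply every encoder layer except the last one
    for layer in model.encoder.layers[:-1]:
        h = layer(h, src_key_padding_mask=mask)
# h is now the output of the second-to-last TransformerEncoderLayer

Also note that register_forward_hook returns a handle, so you can call handle.remove() once you have collected the activations you need.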
Upvotes: 4