Alaa Grable

Reputation: 101

How can I get all the outputs of the last transformer encoder in a pretrained BERT model, and not just the [CLS] token output?

I'm using PyTorch, and this is the model from the Hugging Face transformers library:

from transformers import BertTokenizerFast, BertForSequenceClassification
bert = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                     num_labels=int(data['class'].nunique()),
                                                     output_attentions=False,
                                                     output_hidden_states=False)

In the forward function I'm building, I'm calling x1, x2 = self.bert(sent_id, attention_mask=mask). As far as I know, x2 is the [CLS] output (which, I believe, is the output of the first transformer encoder), but then again, I don't think I understand the model's outputs. I want the outputs of all 12 transformer encoder layers. How can I do that in PyTorch?

Upvotes: 3

Views: 4025

Answers (2)

Eric

Reputation: 401

This is detailed in the docs: https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel

from transformers import BertModel, BertConfig

config = BertConfig.from_pretrained("xxx", output_hidden_states=True)
model = BertModel.from_pretrained("xxx", config=config)

outputs = model(inputs)  # `inputs` are the tokenized input IDs, e.g. produced by a tokenizer
print(len(outputs))  # 3

hidden_states = outputs[2]
print(len(hidden_states))  # 13

embedding_output = hidden_states[0]
attention_hidden_states = hidden_states[1:]
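
For instance, a runnable version of the above (a sketch: the bert-base-uncased checkpoint and the tokenizer call are assumptions standing in for the xxx placeholder):

from transformers import BertTokenizer, BertModel, BertConfig

config = BertConfig.from_pretrained("bert-base-uncased", output_hidden_states=True)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", config=config)

# Tokenize a sample sentence into input IDs and an attention mask
inputs = tokenizer("Hello, world!", return_tensors="pt")

outputs = model(**inputs)
hidden_states = outputs[2]  # embedding output + one tensor per encoder layer

print(len(hidden_states))       # 13
print(hidden_states[-1].shape)  # (1, sequence_length, 768), the last encoder layer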

Upvotes: 0

Ashwin Geet D'Sa

Reputation: 7379

Ideally, if you want to look into the outputs of all the layers, you should use BertModel and not BertForSequenceClassification, because BertForSequenceClassification wraps BertModel and adds a linear classification layer on top of it.
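
As an aside (an illustration added here, not part of the original answer): BertForSequenceClassification stores the underlying encoder as its bert attribute, so you can also reach the plain BertModel from an existing classification model:

from transformers import BertForSequenceClassification

clf = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
encoder = clf.bert  # the wrapped BertModel

That said, loading BertModel directly, as recommended, is simpler: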

from transformers import BertModel
my_bert_model = BertModel.from_pretrained("bert-base-uncased")

### Add your code to move the model and data to the device, and to obtain ids (input IDs) and mask (attention mask)

sequence_output, pooled_output = my_bert_model(ids, attention_mask=mask)

# sequence_output has shape (batch_size, sequence_length, 768)

sequence_output contains the output for every token from the last layer of the BERT model.
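
For example, if you only need one vector per input, a common choice (shown as an illustration; not part of the original answer) is the vector at the [CLS] position, i.e. the first token:

cls_vectors = sequence_output[:, 0, :]  # [CLS] is the first token; shape: (batch_size, 768)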

In order to obtain the outputs of all the transformer encoder layers, you can use the following:

my_bert_model = BertModel.from_pretrained("bert-base-uncased")
sequence_output, pooled_output, all_layer_output = my_bert_model(ids, attention_mask=mask, output_hidden_states=True)

all_layer_output is a tuple containing the output of the embedding layer plus the outputs of all 12 encoder layers, 13 elements in total. Each element in the tuple has shape (batch_size, sequence_length, 768).

Hence, to get the sequence of outputs at layer 5, you can use all_layer_output[5], since all_layer_output[0] contains the outputs of the embedding layer.
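
Putting it all together, a minimal end-to-end sketch (the tokenizer and the example sentence are illustrative additions):

import torch
from transformers import BertTokenizerFast, BertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
my_bert_model = BertModel.from_pretrained("bert-base-uncased")

encoded = tokenizer("This is a test sentence.", return_tensors="pt")
ids, mask = encoded["input_ids"], encoded["attention_mask"]

with torch.no_grad():
    outputs = my_bert_model(ids, attention_mask=mask, output_hidden_states=True)

# Index 2 holds the hidden states, whether the model returns a plain tuple
# (older transformers versions) or a ModelOutput (newer versions)
all_layer_output = outputs[2]

print(len(all_layer_output))      # 13: embeddings + 12 encoder layers
print(all_layer_output[5].shape)  # (1, sequence_length, 768), outputs of layer 5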

Upvotes: 3
