I'm trying to fine-tune a BERT variant called camel-msa, passing the generated word embeddings into an LSTM and, in a second experiment, into a BiLSTM.
I'm using the same functions and classes for both; the LSTM parts are marked with #LSTM comments in the code below. My question is: can an LSTM perform better than a BiLSTM when it is fed BERT's word embeddings? As far as I know, a BiLSTM understands the context of a text better than an LSTM, so how do we explain that?
I'm passing an Arabic patent dataset with long texts, but the max_len of the tokens is 350. I'm using PyTorch.
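For reference, the tokenization step looks roughly like this (a minimal sketch, not my exact script; the checkpoint id CAMeL-Lab/bert-base-arabic-camelbert-msa and the variable texts are assumptions written out for completeness):

from transformers import AutoTokenizer

# Assumed checkpoint id for camel-msa (adjust if yours differs)
model_name = "CAMeL-Lab/bert-base-arabic-camelbert-msa"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# texts is the list of Arabic patent documents
encodings = tokenizer(texts,
                      padding="max_length",
                      truncation=True,   # long texts are cut off here
                      max_length=350,
                      return_tensors="pt")
input_ids = encodings["input_ids"]
attention_mask = encodings["attention_mask"]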
This is the class:
# Bert-BiLSTM-Classifier class
import torch.nn as nn
from transformers import AutoModel

class BertBilstmClassifier(nn.Module):
    def __init__(self, freeze_bert=False):
        super(BertBilstmClassifier, self).__init__()
        # Hidden size of BERT, hidden size of our classifier, and number of labels
        D_in, H, D_out = 768, 50, 8
        # Instantiate BERT model (model_name is defined above)
        self.bert = AutoModel.from_pretrained(model_name)
        # Instantiate a small feed-forward classifier
        self.classifier = nn.Sequential(
            # BiLSTM (its output is 2*H wide, one H per direction)
            # nn.Linear(2 * H, H),
            # LSTM
            nn.Linear(H, H),
            nn.ReLU(),
            nn.Linear(H, D_out)
        )
        # LSTM (batch_first=True because BERT outputs (batch, seq_len, hidden))
        self.bilstm = nn.LSTM(D_in, H, batch_first=True, bidirectional=False)
        # BiLSTM
        # self.bilstm = nn.LSTM(D_in, H, batch_first=True, bidirectional=True)
        # Freeze the BERT model
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        # Feed input to BERT; outputs[0] is (batch_size, seq_len, 768)
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        # Run the token embeddings through the (Bi)LSTM
        lstm_out, _ = self.bilstm(outputs[0])
        # Take the (Bi)LSTM output at the `[CLS]` position for the classification task
        last_hidden_state_cls = lstm_out[:, 0, :]
        # Feed it to the classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)
        return logits
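A quick sanity check of the shapes with a dummy batch (just a sketch for illustration, not part of my training code):

import torch

model = BertBilstmClassifier(freeze_bert=True)
dummy_ids = torch.randint(0, 1000, (2, 350))       # (batch, seq_len)
dummy_mask = torch.ones(2, 350, dtype=torch.long)
print(model(dummy_ids, dummy_mask).shape)          # torch.Size([2, 8]) -- 8 labels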
This is the model initialization function:
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

def initialize_model(epochs=4):
    """Initialize the BERT classifier, the optimizer and the learning rate scheduler."""
    # Instantiate the BERT classifier
    bert_classifier = BertBilstmClassifier(freeze_bert=False)
    # Tell PyTorch to run the model on the GPU
    bert_classifier.to(device)
    # Create the optimizer
    optimizer = AdamW(bert_classifier.parameters(),
                      lr=2e-5,   # default learning rate
                      eps=1e-8)  # default epsilon value
    # Total number of training steps
    total_steps = len(train_dataloader) * epochs
    # Set up the learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,  # default value
                                                num_training_steps=total_steps)
    return bert_classifier, optimizer, scheduler
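I then use these in a standard training loop along these lines (a sketch; it assumes the DataLoader yields (input_ids, attention_mask, labels) tuples, which may differ from your setup):

bert_classifier, optimizer, scheduler = initialize_model(epochs=4)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(4):
    for batch in train_dataloader:
        # assumption: each batch is an (input_ids, attention_mask, labels) tuple
        input_ids, attention_mask, labels = (t.to(device) for t in batch)
        logits = bert_classifier(input_ids, attention_mask)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()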
THANK YOU!