Taif Alomar

Passing BERT-variant word embeddings into an LSTM classifier performs better than passing them into a BiLSTM

I'm fine-tuning a BERT variant called camel-msa and passing the generated word embeddings into an LSTM, and in a separate experiment into a BiLSTM.

I'm using the same functions and classes for both; the LSTM-specific parts are marked with comments (#LSTM) in the code below. My question is: can an LSTM perform better than a BiLSTM when it's fed BERT's word embeddings? As far as I know, a BiLSTM captures the context of text better than an LSTM, so how can this result be explained?

I'm training on an Arabic patent dataset with long texts, but the max_len of tokens is 350. I'm using PyTorch.
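For context, the inputs are tokenized roughly like this (a sketch, not my exact preprocessing; the checkpoint id is an assumption about what camel-msa resolves to on the Hugging Face Hub):

from transformers import AutoTokenizer

model_name = "CAMeL-Lab/bert-base-arabic-camelbert-msa"  # assumed Hub id for camel-msa
tokenizer = AutoTokenizer.from_pretrained(model_name)

texts = ["example patent abstract"]  # placeholder for the Arabic patent texts
encoded = tokenizer(texts,
                    padding="max_length",
                    truncation=True,
                    max_length=350,
                    return_tensors="pt")
input_ids = encoded["input_ids"]            # shape: (batch_size, 350)
attention_mask = encoded["attention_mask"]  # shape: (batch_size, 350)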

This is the class:

import torch.nn as nn
from transformers import AutoModel

# BERT-(Bi)LSTM classifier class
# model_name (the camel-msa checkpoint id) is defined earlier in the script
class BertBilstmClassifier(nn.Module):

    def __init__(self, freeze_bert=False):
      
        super(BertBilstmClassifier, self).__init__()
        # Specify hidden size of BERT, hidden size of our classifier, and number of labels
        D_in, H, D_out = 768, 50, 8

        # Instantiate BERT model
        self.bert = AutoModel.from_pretrained(model_name)

        # Instantiate a small two-layer feed-forward classifier head
        self.classifier = nn.Sequential(
            #BiLSTM
            # nn.Linear(2*H, H),
            #LSTM
            nn.Linear(H, H),
            nn.ReLU(),
            nn.Linear(H, D_out)
        )
        # LSTM (BERT's output is batch-first, (batch, seq_len, hidden), so batch_first=True)
        self.bilstm = nn.LSTM(D_in, H, batch_first=True, bidirectional=False)
        # BiLSTM
        # self.bilstm = nn.LSTM(D_in, H, batch_first=True, bidirectional=True)

        # Freeze the BERT model
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        
    def forward(self, input_ids, attention_mask):
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        # outputs[0] is the last hidden state, shape (batch_size, seq_len, 768)

        # Run the sequence of BERT embeddings through the (Bi)LSTM
        lstm_out, _ = self.bilstm(outputs[0])

        # Take the (Bi)LSTM output at the `[CLS]` position for classification
        last_hidden_state_cls = lstm_out[:, 0, :]

        # Feed input to classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)

        return logits
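One thing worth double-checking is the batch_first flag: BERT's last hidden state is laid out as (batch, seq_len, hidden), so the LSTM has to be built with batch_first=True for lstm_out[:, 0, :] to actually be the [CLS] position. A standalone shape check:

import torch
import torch.nn as nn

x = torch.randn(4, 350, 768)  # same layout as BERT's last hidden state

lstm = nn.LSTM(768, 50, batch_first=True, bidirectional=False)
out, _ = lstm(x)
print(out.shape)   # torch.Size([4, 350, 50])  -> out[:, 0, :] is the [CLS] step

bilstm = nn.LSTM(768, 50, batch_first=True, bidirectional=True)
out2, _ = bilstm(x)
print(out2.shape)  # torch.Size([4, 350, 100]) -> feature dim doubles, hence nn.Linear(2*H, H)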

This is the model initialization function:

from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# device and train_dataloader are defined elsewhere in the script
def initialize_model(epochs=4):
    """Initialize the Bert Classifier, the optimizer and the learning rate scheduler.
    """
    # Instantiate Bert Classifier
    bert_classifier = BertBilstmClassifier(freeze_bert=False)

    # Tell PyTorch to run the model on GPU
    bert_classifier.to(device)

    # Create the optimizer
    optimizer = AdamW(bert_classifier.parameters(),
                      lr=2e-5,    # Default learning rate
                      eps=1e-8    # Default epsilon value
                      )

    # Total number of training steps
    total_steps = len(train_dataloader) * epochs

    # Set up the learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0, # Default value
                                                num_training_steps=total_steps)
    return bert_classifier, optimizer, scheduler    
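And the training loop steps the optimizer and scheduler roughly like this (a sketch; train_dataloader, device, and the batch layout are assumed from the rest of my script):

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

def train(model, optimizer, scheduler, epochs=4):
    model.train()
    for epoch in range(epochs):
        for input_ids, attention_mask, labels in train_dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            loss = loss_fn(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # common for BERT fine-tuning
            optimizer.step()
            scheduler.step()  # linear warmup schedule steps per batch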

THANK YOU!
