Reputation: 23
I fine-tuned two separate BERT models (bert-base-uncased), one for sentiment analysis and one for POS tagging. Now I want to feed the output of the POS tagger, of shape (batch, seqlength, hiddensize), as input to the sentiment model. The original bert-base-uncased model is in the 'bertModel/' folder, which contains 'model.bin' and 'config.json'. Here is my code:
import torch
import torch.nn as nn
from transformers import BertModel


class DeepSequentialModel(nn.Module):
    def __init__(self, sentiment_model_file, postag_model_file, device):
        super(DeepSequentialModel, self).__init__()
        # Load the two fine-tuned models from their saved state dicts
        self.sentiment_model = SentimentModel().to(device)
        self.sentiment_model.load_state_dict(torch.load(sentiment_model_file, map_location=device))
        self.postag_model = PosTagModel().to(device)
        self.postag_model.load_state_dict(torch.load(postag_model_file, map_location=device))
        self.classificationLayer = nn.Linear(768, 1)

    def forward(self, seq, attn_masks):
        # POS tagger output: (batch, seqlength, hiddensize)
        postag_context = self.postag_model(seq, attn_masks)
        sent_context = self.sentiment_model(postag_context, attn_masks)
        logits = self.classificationLayer(sent_context)
        return logits


class PosTagModel(nn.Module):
    def __init__(self,):
        super(PosTagModel, self).__init__()
        self.bert_layer = BertModel.from_pretrained('bertModel/')
        self.classificationLayer = nn.Linear(768, 43)

    def forward(self, seq, attn_masks):
        cont_reps, _ = self.bert_layer(seq, attention_mask=attn_masks)
        return cont_reps


class SentimentModel(nn.Module):
    def __init__(self,):
        super(SentimentModel, self).__init__()
        self.bert_layer = BertModel.from_pretrained('bertModel/')
        self.cls_layer = nn.Linear(768, 1)

    def forward(self, input, attn_masks):
        # This is the call that raises the TypeError shown below
        cont_reps, _ = self.bert_layer(encoder_hidden_states=input, encoder_attention_mask=attn_masks)
        # Representation of the first ([CLS]) token
        cls_rep = cont_reps[:, 0]
        return cls_rep
But I get the error below. I would appreciate it if someone could help me. Thanks!
cont_reps, _ = self.bert_layer(encoder_hidden_states=input, encoder_attention_mask=attn_masks)
result = self.forward(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'
Upvotes: 2
Views: 845
Reputation: 11430
To formulate this as an answer, too, and keep it properly visible for future visitors: the forward() call of transformers does not support these arguments in version 2.1.1, or in any earlier version, for that matter. Note that the link in my comment in fact points to a different forward function, but otherwise the point still holds.
Passing encoder_hidden_states to forward() first became possible in version 2.2.0.
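If you are unsure which situation you are in, a small check along these lines may help. It is only a sketch: accepts_keyword, version_tuple, old_forward and new_forward are names made up for illustration, and the two stand-in functions merely imitate the two kinds of signature rather than reproduce the real BertModel.forward.

```python
import inspect
import re


def accepts_keyword(func, name):
    """Return True if `func` can be called with `name` as a keyword argument."""
    params = inspect.signature(func).parameters
    if name in params:
        return params[name].kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
    # A **kwargs catch-all also accepts arbitrary keyword names.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def version_tuple(version):
    """Turn a version string such as '2.1.1' into (2, 1, 1) for comparison."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError("unrecognised version string: %r" % version)
    return tuple(int(group or 0) for group in match.groups())


# Hypothetical stand-ins that only imitate an older and a newer signature.
def old_forward(self, input_ids, attention_mask=None):
    pass


def new_forward(self, input_ids, attention_mask=None, encoder_hidden_states=None):
    pass


print(accepts_keyword(old_forward, "encoder_hidden_states"))
print(accepts_keyword(new_forward, "encoder_hidden_states"))
print(version_tuple("2.1.1") < version_tuple("2.2.0"))
```

In your own environment you would pass BertModel.forward instead of the stand-ins and compare version_tuple(transformers.__version__) against (2, 2, 0). If the check fails, upgrading transformers to 2.2.0 or later should make the TypeError go away.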
Upvotes: 1