Reputation: 9
I tried to build a model for a multi-label text classification task in Chinese, but the model's performance is not good enough (about 60% accuracy), and I'm here for help on how to improve it.
I built the model based on a GitHub project:
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel

class BertMultiLabelCls(nn.Module):
    def __init__(self, hidden_size, class_num, dropout=0.1):
        super(BertMultiLabelCls, self).__init__()
        self.fc = nn.Linear(hidden_size, class_num)
        self.drop = nn.Dropout(dropout)
        self.bert = BertModel.from_pretrained("bert-base-chinese")

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        cls = self.drop(outputs[1])  # outputs[1] is the pooled [CLS] representation
        out = F.sigmoid(self.fc(cls))  # per-label probabilities for multi-label output
        return out
My dataset is 2,000 query-tag pairs with 13 tags; the queries are questions asked by the audience during live-commerce streams. I split the dataset 3:1:1 into train/test/validation. My tags are NOT balanced and no up-sampling/down-sampling strategy is used.
Loss and accuracy during training, where the horizontal axis is the epoch:
The validation accuracy stopped increasing at around 60%, and I get the same result on my test dataset. I've tried various methods, including adding more fully connected layers and adding a residual connection, but the result remains the same.
Here are my training params if it helps:
lr = 2e-5
batch_size = 128
max_len = 64
hidden_size = 768
epochs = 30
optimizer = AdamW(model.parameters(), lr=lr)
criterion = nn.BCELoss() # loss function
Any suggestions on how to improve my model, aside from the dataset? I'm already working on the data in parallel and know how to improve it, but I'm a real newbie when it comes to the network itself.
Upvotes: 0
Views: 137
Reputation: 1
Given your imbalanced dataset, Focal Loss could be a valuable alternative to BCELoss. It focuses on hard-to-classify examples, reducing the influence of easy negatives that dominate the loss in imbalanced scenarios.
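Here is a minimal sketch of a multi-label focal loss. The gamma and alpha values are assumed defaults you would tune on validation data, and it expects raw logits, so you would remove the sigmoid from the model's forward pass (or adapt it to take probabilities):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiLabelFocalLoss(nn.Module):
    # Sketch of focal loss for multi-label classification.
    # gamma and alpha are assumed hyperparameters, not tuned values.
    # Expects raw logits (no sigmoid applied in the model).
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        # Per-element binary cross-entropy, no reduction yet
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        probs = torch.sigmoid(logits)
        # p_t: the model's probability assigned to the true value of each label
        p_t = probs * targets + (1 - probs) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        # Down-weight easy examples (p_t close to 1), keep hard ones
        loss = alpha_t * (1 - p_t) ** self.gamma * bce
        return loss.mean()

# criterion = MultiLabelFocalLoss()  # drop-in replacement for nn.BCELoss()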
Convolutional neural networks (CNNs) can capture local patterns and n-gram features, complementing the global context captured by BERT. Consider adding a CNN layer after the BERT encoder, for example over its token-level outputs, as sketched below.
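One possible arrangement is a few 1-D convolutions over BERT's token embeddings followed by max-pooling. The kernel sizes and filter count below are illustrative assumptions, not tuned values:

import torch
import torch.nn as nn
from transformers import BertModel

class BertCnnMultiLabelCls(nn.Module):
    # Sketch: 1-D convolutions over BERT token outputs.
    # num_filters and kernel sizes (2, 3, 4) are illustrative assumptions.
    def __init__(self, hidden_size, class_num, dropout=0.1, num_filters=128):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-chinese")
        self.convs = nn.ModuleList(
            [nn.Conv1d(hidden_size, num_filters, kernel_size=k, padding=k // 2)
             for k in (2, 3, 4)])  # roughly bigram/trigram/4-gram patterns
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * 3, class_num)

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        # (batch, seq_len, hidden) -> (batch, hidden, seq_len) for Conv1d
        tokens = outputs[0].transpose(1, 2)
        # Max-pool each convolution's output over the sequence dimension
        pooled = [torch.relu(conv(tokens)).max(dim=2).values for conv in self.convs]
        features = self.drop(torch.cat(pooled, dim=1))
        return torch.sigmoid(self.fc(features))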
We can also implement a callback to monitor validation accuracy and stop training early if it doesn't improve for a certain number of epochs. This helps prevent overfitting to the training data; a sketch follows.
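A minimal early-stopping sketch for a plain PyTorch training loop. The patience value, the train_one_epoch/evaluate helpers, and the checkpoint filename are placeholders, not part of your original code:

import torch

best_val_acc = 0.0
patience = 5                      # assumed value; tune on your data
epochs_without_improvement = 0

for epoch in range(epochs):
    train_one_epoch(model, train_loader, optimizer, criterion)  # hypothetical helper
    val_acc = evaluate(model, val_loader)                        # hypothetical helper
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        epochs_without_improvement = 0
        torch.save(model.state_dict(), "best_model.pt")  # keep the best checkpoint
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Stopping early at epoch {epoch}: no improvement for {patience} epochs")
            break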
Upvotes: 0