tardis blue

Reputation: 9

Enhance model performance on a text classification task

I tried to build a model for a multi-label text classification task in Chinese, but the model's performance isn't good enough (about 60% accuracy), and I'm looking for help on how to improve it.

I built the model based on a GitHub project:

import torch
import torch.nn as nn
from transformers import BertModel


class BertMultiLabelCls(nn.Module):
    def __init__(self, hidden_size, class_num, dropout=0.1):
        super(BertMultiLabelCls, self).__init__()
        self.fc = nn.Linear(hidden_size, class_num)
        self.drop = nn.Dropout(dropout)
        self.bert = BertModel.from_pretrained("bert-base-chinese")

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        # outputs[1] is the pooled [CLS] representation
        cls = self.drop(outputs[1])
        # torch.sigmoid replaces the deprecated F.sigmoid;
        # one independent probability per label
        out = torch.sigmoid(self.fc(cls))
        return out

My dataset is 2,000 query-tag pairs with 13 tags; the queries are questions asked by the audience in live-commerce streams. I split the dataset 3:1:1 into train/test/val. The tags are NOT balanced, and no up-sampling/down-sampling strategy is used.

[Figure: loss and accuracy during training; the horizontal axis is the epoch]

The validation accuracy stopped increasing near 60%, with the same result on my test dataset. I've tried various methods, including adding more fully connected layers and residual connections, but the result stays the same.

Here are my training params if it helps:

lr = 2e-5
batch_size = 128
max_len = 64
hidden_size = 768
epochs = 30
optimizer = AdamW(model.parameters(), lr=lr)
criterion = nn.BCELoss()  # binary cross-entropy on the sigmoid outputs

Any suggestions on how to improve the model itself, rather than the dataset? I'm already working on the data in parallel and know how to improve it, but I'm a real newbie when it comes to the network itself.

Upvotes: 0

Views: 137

Answers (1)

Venom0912

Reputation: 1

Given your imbalanced dataset, Focal Loss could be a valuable alternative to BCELoss. It focuses on hard-to-classify examples, reducing the influence of easy negatives that dominate the loss in imbalanced scenarios.
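Here is a minimal sketch of a multi-label focal loss (a simplified form of the original formulation; the alpha and gamma values below are common defaults, not values tuned for your data). Note that it takes raw logits, so you would drop the sigmoid from the model's forward:

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    """Simplified binary focal loss for multi-label classification.

    Expects raw logits (no sigmoid in the model head):
    BCEWithLogitsLoss applies the sigmoid internally and is more
    numerically stable than BCELoss on probabilities.
    """
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, logits, targets):
        bce_loss = self.bce(logits, targets)  # per-element BCE
        p_t = torch.exp(-bce_loss)            # probability of the true label
        # down-weight easy examples (p_t close to 1)
        focal = self.alpha * (1.0 - p_t) ** self.gamma * bce_loss
        return focal.mean()

# usage: criterion = FocalLoss()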

Convolutional neural networks (CNNs) can capture local patterns and n-gram features, complementing the global context captured by BERT. Consider adding a CNN layer before or after the BERT encoder.
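A hypothetical variant along those lines runs 1-D convolutions over BERT's token embeddings and max-pools over the sequence (the kernel sizes and filter count here are illustrative, not tuned):

import torch
import torch.nn as nn
from transformers import BertModel

class BertCNNMultiLabelCls(nn.Module):
    def __init__(self, hidden_size, class_num, kernel_sizes=(2, 3, 4),
                 num_filters=128, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-chinese")
        # one Conv1d per kernel size, each capturing n-gram patterns
        self.convs = nn.ModuleList(
            [nn.Conv1d(hidden_size, num_filters, k) for k in kernel_sizes]
        )
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(kernel_sizes), class_num)

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        # last_hidden_state: (batch, seq_len, hidden) -> (batch, hidden, seq_len)
        x = outputs[0].permute(0, 2, 1)
        # max-pool each convolution's output over the sequence dimension
        pooled = [conv(x).relu().max(dim=2).values for conv in self.convs]
        cat = self.drop(torch.cat(pooled, dim=1))
        return self.fc(cat)  # raw logits; pair with BCEWithLogitsLoss or FocalLoss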

You can also implement a callback that monitors validation accuracy and stops training early if it doesn't improve for a certain number of epochs. This prevents overfitting to the training data.
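A minimal sketch of such a helper (the patience value is illustrative, and train_one_epoch/evaluate stand in for your own training and validation functions):

class EarlyStopper:
    """Stop when validation accuracy hasn't improved for `patience` epochs."""
    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_acc):
        if val_acc > self.best:
            self.best = val_acc
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience  # True -> stop training

# in the training loop:
# stopper = EarlyStopper(patience=5)
# for epoch in range(epochs):
#     train_one_epoch(model, train_loader)
#     if stopper.step(evaluate(model, val_loader)):
#         break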

Upvotes: 0
