stacksonoverflow

Reputation: 111

How to use another pretrained BERT model with the ktrain text classifier?

How can we use a different pretrained model for the text classifier in the ktrain library? When using:

model = text.text_classifier('bert', (x_train, y_train), preproc=preproc)

This uses the multilingual pretrained BERT model.

However, I want to try out a monolingual model as well, namely the Dutch one, 'wietsedv/bert-base-dutch-cased', which is also used in other ktrain implementations.

However, neither of the following attempts works:

model = text.text_classifier('bert', (x_train, y_train),
                             preproc=preproc, bert_model='wietsedv/bert-base-dutch-cased')

or

model = text.text_classifier('wietsedv/bert-base-dutch-cased', (x_train, y_train), preproc=preproc)

Does anyone know how to do this? Thanks!

Upvotes: 1

Views: 2491

Answers (1)

blustax

Reputation: 489

There are two text classification APIs in ktrain. The first is the text_classifier API, which supports only a select set of models, some transformer-based and some not. The second is the Transformer API, which can be used with any model from the Hugging Face transformers library, including the one you listed.

The latter is explained in detail in this tutorial notebook and this Medium article.

For instance, in the example below you can set MODEL_NAME to any model you want:

Example:

# load text data
from sklearn.datasets import fetch_20newsgroups
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
train_b = fetch_20newsgroups(subset='train', categories=categories, shuffle=True)
test_b = fetch_20newsgroups(subset='test', categories=categories, shuffle=True)
(x_train, y_train) = (train_b.data, train_b.target)
(x_test, y_test) = (test_b.data, test_b.target)

# build, train, and validate model (Transformer is a wrapper around the Hugging Face transformers library)
import ktrain
from ktrain import text
MODEL_NAME = 'distilbert-base-uncased'  # replace this with model of choice
t = text.Transformer(MODEL_NAME, maxlen=500, class_names=train_b.target_names)
trn = t.preprocess_train(x_train, y_train)
val = t.preprocess_test(x_test, y_test)
model = t.get_classifier()
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)
learner.fit_onecycle(5e-5, 4)
learner.validate(class_names=t.get_classes()) # class_names must be string values

# Output from learner.validate()
#                        precision    recall  f1-score   support
#
#           alt.atheism       0.92      0.93      0.93       319
#         comp.graphics       0.97      0.97      0.97       389
#               sci.med       0.97      0.95      0.96       396
#soc.religion.christian       0.96      0.96      0.96       398
#
#              accuracy                           0.96      1502
#             macro avg       0.95      0.96      0.95      1502
#          weighted avg       0.96      0.96      0.96      1502

Upvotes: 5
