Reputation: 111
How can we use a different pretrained model for the text classifier in the ktrain library? When using:
model = text.text_classifier('bert', (x_train, y_train) , preproc=preproc)
This uses the multilingual pretrained model.
However, I want to try out a monolingual model as well, namely the Dutch one, 'wietsedv/bert-base-dutch-cased', which is also used in other ktrain implementations.
However, when I try to pass it to the text classifier, it does not work:
model = text.text_classifier('bert', (x_train, y_train),
                             preproc=preproc, bert_model='wietsedv/bert-base-dutch-cased')
or
model = text.text_classifier('wietsedv/bert-base-dutch-cased', (x_train, y_train), preproc=preproc)
Does anyone know how to do this? Thanks!
Upvotes: 1
Views: 2491
Reputation: 489
There are two text classification APIs in ktrain. The first is the text_classifier API, which can be used for a select set of transformers and non-transformers models. The second is the Transformer API, which can be used with any transformers model, including the one you listed.
The latter is explained in detail in this tutorial notebook and this Medium article.
For instance, you can replace MODEL_NAME with any model you want in the example below:
Example:
# load text data
categories = ['alt.atheism', 'soc.religion.christian','comp.graphics', 'sci.med']
from sklearn.datasets import fetch_20newsgroups
train_b = fetch_20newsgroups(subset='train', categories=categories, shuffle=True)
test_b = fetch_20newsgroups(subset='test',categories=categories, shuffle=True)
(x_train, y_train) = (train_b.data, train_b.target)
(x_test, y_test) = (test_b.data, test_b.target)
# build, train, and validate model (Transformer is a wrapper around the transformers library)
import ktrain
from ktrain import text
MODEL_NAME = 'distilbert-base-uncased' # replace this with model of choice
t = text.Transformer(MODEL_NAME, maxlen=500, class_names=train_b.target_names)
trn = t.preprocess_train(x_train, y_train)
val = t.preprocess_test(x_test, y_test)
model = t.get_classifier()
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)
learner.fit_onecycle(5e-5, 4)
learner.validate(class_names=t.get_classes()) # class_names must be string values
# Output from learner.validate()
#                         precision    recall  f1-score   support
#
#            alt.atheism       0.92      0.93      0.93       319
#          comp.graphics       0.97      0.97      0.97       389
#                sci.med       0.97      0.95      0.96       396
# soc.religion.christian       0.96      0.96      0.96       398
#
#               accuracy                           0.96      1502
#              macro avg       0.95      0.96      0.95      1502
#           weighted avg       0.96      0.96      0.96      1502
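For the Dutch model from the question, the same steps could be wrapped roughly as follows. This is only a minimal sketch: build_dutch_learner is a hypothetical helper name (it is not part of ktrain), and it assumes that ktrain is installed and that the model name from the question can be downloaded through the transformers library.

```python
# Model name taken from the question; any other transformers model name
# could be passed in its place.
DUTCH_MODEL = 'wietsedv/bert-base-dutch-cased'


def build_dutch_learner(x_train, y_train, x_val, y_val, class_names,
                        model_name=DUTCH_MODEL, maxlen=500, batch_size=6):
    """Hypothetical helper: repeats the Transformer API steps shown above.

    Returns the preprocessor and the learner so that the caller can train
    and validate. ktrain is imported lazily, so defining this function does
    not load TensorFlow or download any model.
    """
    import ktrain
    from ktrain import text

    # Same calls as in the example above, with the Dutch model name.
    t = text.Transformer(model_name, maxlen=maxlen, class_names=class_names)
    trn = t.preprocess_train(x_train, y_train)
    val = t.preprocess_test(x_val, y_val)
    model = t.get_classifier()
    learner = ktrain.get_learner(model, train_data=trn, val_data=val,
                                 batch_size=batch_size)
    return t, learner
```

With your own Dutch texts and labels, you would call t, learner = build_dutch_learner(x_train, y_train, x_test, y_test, class_names) and then train with learner.fit_onecycle(5e-5, 4) as above. The learning rate and number of epochs are carried over from the example and may need tuning for your data.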
Upvotes: 5