I'm trying to learn how to use some ML stuff for Android. I got the Text Classification demo working and it seems to work fine, so then I tried creating my own model.
The code I used to create my own model was this:
import numpy as np
import os
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.text_classifier import AverageWordVecSpec
from tflite_model_maker.text_classifier import DataLoader
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
spec = model_spec.get('mobilebert_classifier')
train_data = DataLoader.from_csv(
    filename='/path to file/train.csv',
    text_column='sentence',
    label_column='label',
    model_spec=spec,
    is_training=True)
model = text_classifier.create(train_data, model_spec=spec, epochs=10)
model.export(export_dir='average_word_vec')
The code appeared to run fine and it created a model.tflite file for me. I then replaced the demo's .tflite file with mine, but when I run the demo I get the following error:
java.lang.AssertionError: Error occurred when initializing NLClassifier: Type mismatch for input tensor serving_default_input_type_ids:0. Requested STRING, got INT32.
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.initJniWithByteBuffer(Native Method)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.access$100(NLClassifier.java:67)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier$2.createHandle(NLClassifier.java:223)
at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromBufferAndOptions(NLClassifier.java:219)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFileAndOptions(NLClassifier.java:175)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFile(NLClassifier.java:150)
at org.tensorflow.lite.examples.textclassification.client.TextClassificationClient.load(TextClassificationClient.java:44)
at org.tensorflow.lite.examples.textclassification.MainActivity.lambda$onStart$1$MainActivity(MainActivity.java:67)
at org.tensorflow.lite.examples.textclassification.-$$Lambda$MainActivity$eJaQnJq74KcmPEczFE5swJIGydg.run(Unknown Source:2)
What am I missing?
In your code you trained a MobileBERT model, but saved it to the average_word_vec path:
spec = model_spec.get('mobilebert_classifier')
model.export(export_dir='average_word_vec')
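One way to keep the spec and the export path from drifting apart is to derive one from the other. This is a minimal sketch using a hypothetical helper `export_plan` (it is not part of Model Maker). The spec-to-class mapping is an assumption based on the TFLite Task Library docs: average word vec models are loaded with `NLClassifier` (the class the demo uses, per your stack trace), while MobileBERT models need `BertNLClassifier`.

```python
from typing import Optional

# Hypothetical helper, not part of tflite_model_maker: derives the export
# directory from the spec name and reports which TFLite Task Library class
# should load the exported model on Android.

# Assumed mapping, based on the Task Library docs: average word vec models
# load with NLClassifier, MobileBERT models with BertNLClassifier.
TASK_CLASS_FOR_SPEC = {
    "average_word_vec": "NLClassifier",
    "mobilebert_classifier": "BertNLClassifier",
}


def export_plan(spec_name: str, export_dir: Optional[str] = None) -> dict:
    """Return the export dir and Task Library class for a Model Maker spec.

    Raises ValueError if export_dir is named after a *different* known spec,
    e.g. a MobileBERT model exported to 'average_word_vec'.
    """
    if spec_name not in TASK_CLASS_FOR_SPEC:
        raise ValueError(f"unknown spec: {spec_name!r}")
    if export_dir is None:
        export_dir = spec_name  # default: name the folder after the spec
    for other in TASK_CLASS_FOR_SPEC:
        if other != spec_name and other in export_dir:
            raise ValueError(
                f"export_dir {export_dir!r} suggests {other!r}, "
                f"but the model was trained with {spec_name!r}"
            )
    return {"export_dir": export_dir, "task_class": TASK_CLASS_FOR_SPEC[spec_name]}


plan = export_plan("mobilebert_classifier")
print(plan)  # {'export_dir': 'mobilebert_classifier', 'task_class': 'BertNLClassifier'}
```

You would then pass `plan["export_dir"]` to `model.export(export_dir=...)` and load the resulting file on Android with the class named in `plan["task_class"]`.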
One possibility is that you are using the average_word_vec model, but with MobileBERT metadata added, so the preprocessing doesn't match.
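To check which kind of model you actually ended up with, you can inspect its input tensors. This is a sketch under a stated heuristic, not an official rule: it assumes Model Maker's MobileBERT export has three int32 inputs whose names contain `input_word_ids`, `input_mask` and `input_type_ids` (your error names `serving_default_input_type_ids:0`), while the average word vec export has a single input. `describe_inputs` and `inspect_model` are hypothetical helpers; `tf.lite.Interpreter` and `get_input_details()` are the standard TensorFlow Lite Python API.

```python
from typing import Dict, List

# Heuristic (an assumption, not an official rule): Model Maker's MobileBERT
# export has three int32 inputs whose names contain these substrings, while
# the average word vec export has a single input tensor.
BERT_INPUT_MARKERS = ("input_word_ids", "input_mask", "input_type_ids")


def describe_inputs(input_details: List[Dict]) -> str:
    """Guess which Task Library class fits a model from its input details.

    input_details: dicts with at least a 'name' key, such as the list returned
    by tf.lite.Interpreter.get_input_details().
    """
    names = [d["name"] for d in input_details]
    if len(names) == 1:
        return "NLClassifier"
    if all(any(marker in name for name in names) for marker in BERT_INPUT_MARKERS):
        return "BertNLClassifier"
    return "unknown"


def inspect_model(model_path: str) -> str:
    """Print each input tensor of a .tflite file and return the guess."""
    import tensorflow as tf  # imported here so describe_inputs works without TF

    interpreter = tf.lite.Interpreter(model_path=model_path)
    details = interpreter.get_input_details()
    for d in details:
        print(d["name"], d["dtype"], d["shape"])
    return describe_inputs(details)
```

If the heuristic holds, running `inspect_model("model.tflite")` on the file you exported should list three inputs, including the one from your error, and return `"BertNLClassifier"`, which would mean the file is a MobileBERT model rather than the average word vec model the demo expects.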
Could you follow the Model Maker tutorial and try again?
https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb
Make sure to change the export path.