Reputation: 967
I have trained a TensorFlow model to predict the next word for an input text and saved it as an .h5 file.
I can use that model from another Python script to predict a word as follows:
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import load_model
model = load_model('model.h5')
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"]
)
data = open("dataset.txt").read()
corpus = data.lower().split("\n")
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)
seed_text = input()
sequence_text = tokenizer.texts_to_sequences([seed_text])[0]
padded_sequence = np.array(pad_sequences([sequence_text], maxlen = 11 -1))
predicted = np.argmax(model.predict(padded_sequence))
Is there a way to use that model directly inside Flutter, taking the input from a TextField() and displaying the predicted word when a button is pressed?
Upvotes: 5
Views: 7850
Reputation: 210
First, convert the model to a .tflite model.
# https://www.tensorflow.org/lite/convert/#convert_a_savedmodel_recommended_
import tensorflow as tf
# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()
# Save the model.
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
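Since the question starts from an .h5 file rather than a SavedModel directory, the converter can also be built from a loaded Keras model with `tf.lite.TFLiteConverter.from_keras_model`. Below is a minimal sketch, assuming the file is named model.h5 as in the question. The helper names `tflite_path_for` and `convert_h5_to_tflite` are hypothetical, and some model architectures may need additional converter options that are not shown here.

```python
from pathlib import Path


def tflite_path_for(h5_path: str) -> Path:
    """Return the output path: same name as the .h5 file, with a .tflite suffix."""
    return Path(h5_path).with_suffix(".tflite")


def convert_h5_to_tflite(h5_path: str) -> Path:
    """Load a Keras .h5 model and write a TensorFlow Lite file next to it."""
    # Imported lazily so the path helper above works without TensorFlow installed.
    import tensorflow as tf

    model = tf.keras.models.load_model(h5_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    out_path = tflite_path_for(h5_path)
    out_path.write_bytes(converter.convert())
    return out_path


# Example call (assumes model.h5 is in the working directory):
# convert_h5_to_tflite("model.h5")
```

The resulting model.tflite is the file that gets copied into the Flutter project.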
Then place model.tflite in the assets/ directory.
android/
assets/
    model.tflite
ios/
lib/
Add the tflite plugin as a dependency in pubspec.yaml:
dependencies:
  flutter:
    sdk: flutter
  tflite: ^1.0.5
.
.
labels.txt is a text file containing the class names:
import 'package:tflite/tflite.dart';
.
.
.
class _MyAppState extends State<MyApp> {
  . . .
  @override
  void initState() {
    super.initState();
    _loading = true;
    loadModel().then((value) {
      setState(() {
        _loading = false;
      });
    });
  }
  classifyImage(File image) async {
    var output = await Tflite.runModelOnImage(
      path: image.path,
      numResults: 2,
      threshold: 0.5,
      imageMean: 127.5,
      imageStd: 127.5,
    );
    setState(() {
      _loading = false;
      _outputs = output;
    });
  }
  loadModel() async {
    await Tflite.loadModel(
      model: "assets/model_unquant.tflite",
      labels: "assets/labels.txt",
    );
  }
  @override
  void dispose() {
    Tflite.close();
    super.dispose();
  }
  . . .
}
The tflite plugin doesn't support text classification as far as I know. If you specifically need text classification, I'd recommend the tflite_flutter plugin instead. The article linked below uses that plugin for text classification:
Text Classification using TensorFlow Lite Plugin for Flutter
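Whichever plugin is used, the app has to reproduce the text preprocessing from the question before it can feed the model, and it has to map the predicted index back to a word. The sketch below shows that logic in plain Python as a reference for porting. It assumes the tokenizer's `word_index` dictionary has been exported (for example as JSON), that the input contains no punctuation, and that the Keras defaults apply (unknown words dropped, zeros padded on the left, only the last `maxlen` tokens kept). `encode` and `decode` are hypothetical names.

```python
from typing import Dict, List


def encode(text: str, word_index: Dict[str, int], maxlen: int = 10) -> List[int]:
    """Turn text into a fixed-length list of word ids, left-padded with zeros."""
    # Lowercase, split on whitespace, and drop words missing from the vocabulary.
    ids = [word_index[w] for w in text.lower().split() if w in word_index]
    # Keep only the last `maxlen` ids, then pad on the left.
    ids = ids[-maxlen:]
    return [0] * (maxlen - len(ids)) + ids


def decode(predicted_id: int, word_index: Dict[str, int]) -> str:
    """Map the argmax of the model output back to a word ('' if unknown)."""
    index_word = {i: w for w, i in word_index.items()}
    return index_word.get(predicted_id, "")
```

The default `maxlen` of 10 matches the `11 - 1` used in the question. The same two steps would need to be written in Dart around the interpreter call.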
Upvotes: 3
Reputation: 29
You cannot use an .h5 file directly in Flutter. You will need to either convert it to a .tflite file and use that, or serve the model through a REST API.
Converting it to a .tflite file is the easier option. You can check the following article for more details: https://medium.com/analytics-vidhya/run-cnn-model-in-flutter-10c944cadcba
If you want to create a REST API, check out this article: https://medium.com/analytics-vidhya/deploy-ml-models-using-flask-as-rest-api-and-access-via-flutter-app-7ce63d5c1f3b
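To give a rough idea of the REST route, here is a minimal sketch that uses only the Python standard library instead of the Flask setup from the linked article. The `/predict` path, the JSON field names `text` and `word`, and the functions `make_server` and `placeholder_predict` are all assumptions made for illustration. The placeholder would be replaced by the real model call from the question.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Callable


def make_server(predict: Callable[[str], str], port: int = 8000) -> HTTPServer:
    """Build an HTTP server that answers POST /predict with a predicted word."""

    class Handler(BaseHTTPRequestHandler):
        def do_POST(self) -> None:
            if self.path != "/predict":
                self.send_error(404)
                return
            length = int(self.headers.get("Content-Length", 0))
            payload = json.loads(self.rfile.read(length) or b"{}")
            body = json.dumps({"word": predict(payload.get("text", ""))}).encode()
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def log_message(self, *args) -> None:
            pass  # keep the console quiet

    return HTTPServer(("127.0.0.1", port), Handler)


def placeholder_predict(text: str) -> str:
    """Stand-in for the real model call; it just echoes the last input word."""
    words = text.split()
    return words[-1] if words else ""


# To run: make_server(placeholder_predict).serve_forever()
```

The Flutter app would then POST the TextField contents as JSON and display the `word` field of the response. The server binds to 127.0.0.1 here, so it would need a reachable address before a phone could call it.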
Upvotes: 2