Karan Owalekar

Reputation: 967

How to use a trained TensorFlow model in Flutter?

I have trained a TensorFlow model to predict the next word for an input text and saved it as an .h5 file.

I can load that model in another Python script and predict a word as follows:

import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import load_model

model = load_model('model.h5')
model.compile(
    loss = "categorical_crossentropy",
    optimizer = "adam",
    metrics = ["accuracy"]
)

data = open("dataset.txt").read()
corpus = data.lower().split("\n")
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)

seed_text = input()

sequence_text = tokenizer.texts_to_sequences([seed_text])[0]
padded_sequence = np.array(pad_sequences([sequence_text], maxlen = 11 -1))
predicted = np.argmax(model.predict(padded_sequence))

Is there a way to use that model directly inside Flutter, so that I can take input from a TextField() and display the predicted word when a button is pressed?

Upvotes: 5

Views: 7850

Answers (2)

Saurav Maheshkar

Reputation: 210

Steps

  1. Convert the model into a .tflite model.
# https://www.tensorflow.org/lite/convert/#convert_a_savedmodel_recommended_

import tensorflow as tf

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)
  2. Add the .tflite model to the app directory. I usually put the model in an assets/ directory. The file also has to be listed under flutter: assets: in pubspec.yaml so that it is bundled with the app.
android/
assets/
    model.tflite
ios/
lib/
  3. Add tflite as a dependency in pubspec.yaml.
dependencies:
  flutter:
    sdk: flutter
  tflite: ^1.0.5
  .
  .
  4. Run inference in your Dart code. The following snippet, for example, classifies an image, where labels.txt is a text file containing the class names:
import 'dart:io'; // provides File, used by classifyImage below

import 'package:tflite/tflite.dart';
.
.
.

class _MyAppState extends State<MyApp> {
  . . .
  @override
  void initState() {
    super.initState();
    _loading = true;

    loadModel().then((value) {
      setState(() {
        _loading = false;
      });
    });
  }

  classifyImage(File image) async {
    var output = await Tflite.runModelOnImage(
      path: image.path,
      numResults: 2,
      threshold: 0.5,
      imageMean: 127.5,
      imageStd: 127.5,
    );
    setState(() {
      _loading = false;
      _outputs = output;
    });
  }

  loadModel() async {
    await Tflite.loadModel(
      model: "assets/model_unquant.tflite",
      labels: "assets/labels.txt",
    );
  }
  @override
  void dispose() {
    Tflite.close();
    super.dispose();
  }
 . . .
}


SideNote

As far as I know, the tflite plugin doesn't support text classification. If you specifically need text classification, I'd recommend the tflite_flutter plugin instead. The article below uses that plugin for text classification.

Text Classification using TensorFlow Lite Plugin for Flutter
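
Whichever plugin runs the model, the app has to reproduce the preprocessing from the question's Python script, because the .tflite file only contains the network and not the Tokenizer. The following is a minimal Python sketch of that logic, meant as a reference to port to Dart. It assumes the Tokenizer was created with its default settings (lowercasing, the default punctuation filter, no oov_token) and that pad_sequences used its defaults with maxlen = 10, as in the question. The names encode, decode, export_vocab and vocab.json are made up for illustration.

```python
import json

# Default `filters` string of the Keras Tokenizer; each character is replaced by a space.
KERAS_DEFAULT_FILTERS = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n'
_FILTER_TABLE = str.maketrans({ch: " " for ch in KERAS_DEFAULT_FILTERS})


def encode(text, word_index, maxlen=10):
    """Approximate texts_to_sequences + pad_sequences with default settings."""
    words = text.lower().translate(_FILTER_TABLE).split()
    # Words missing from the vocabulary are dropped (assumes no oov_token was set).
    ids = [word_index[w] for w in words if w in word_index]
    ids = ids[-maxlen:]                      # truncating='pre' keeps the last maxlen ids
    return [0] * (maxlen - len(ids)) + ids   # padding='pre' left-pads with 0


def decode(scores, word_index):
    """Return the word whose index has the highest score (argmax)."""
    best = max(range(len(scores)), key=scores.__getitem__)
    index_word = {i: w for w, i in word_index.items()}
    return index_word.get(best)  # None for index 0, which is reserved for padding


def export_vocab(word_index, path="vocab.json"):
    """Write the vocabulary to JSON so the app can bundle it as an asset."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(word_index, f, ensure_ascii=False)
```

In the question's script, tokenizer.word_index would be the dictionary passed as word_index. On the Flutter side, the exported JSON can be loaded as an asset, the same encode step applied to the TextField input, and the model output passed through the same decode step.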

Upvotes: 3

DrunkOnBytes

Reputation: 29

You cannot use a .h5 file directly in Flutter. You will need to either convert it into a .tflite file and bundle that with the app, or serve the model behind a REST API that the app calls.

Converting it to a .tflite file is the easiest option. You can check the following article for more details: https://medium.com/analytics-vidhya/run-cnn-model-in-flutter-10c944cadcba
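
For a Keras .h5 file like the one in the question, the conversion can be sketched as below. This uses tf.lite.TFLiteConverter.from_keras_model, which takes a loaded Keras model rather than a SavedModel directory. The function name convert_h5_to_tflite and the default output path are made up for illustration, and models with recurrent layers may need extra converter settings depending on the TensorFlow version.

```python
import os


def convert_h5_to_tflite(h5_path, tflite_path="model.tflite"):
    """Load a Keras .h5 model and write it out as a TensorFlow Lite file."""
    if not os.path.isfile(h5_path):
        raise FileNotFoundError(f"No Keras model found at {h5_path!r}")

    # Imported here because TensorFlow is slow to import and is only
    # needed for the conversion itself.
    import tensorflow as tf

    model = tf.keras.models.load_model(h5_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_bytes = converter.convert()

    with open(tflite_path, "wb") as f:
        f.write(tflite_bytes)
    return tflite_path
```

Calling convert_h5_to_tflite("model.h5") should leave a model.tflite file next to the script, which can then be copied into the Flutter project's assets.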

If you want to create a REST API, check out this article: https://medium.com/analytics-vidhya/deploy-ml-models-using-flask-as-rest-api-and-access-via-flutter-app-7ce63d5c1f3b
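
To give a rough idea of the REST route, here is a minimal sketch of such an endpoint. Note that it uses Python's built-in http.server instead of Flask (which the linked article uses) so that it has no extra dependencies. The function predict_next_word is a hypothetical stand-in for the tokenize, model.predict and decode steps from the question, and the "text" and "word" field names and port 8000 are assumptions as well.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


def predict_next_word(text):
    """Hypothetical stand-in: replace with tokenizing, model.predict and decoding."""
    return "placeholder"


def handle_predict(body, predict_fn=predict_next_word):
    """Turn a JSON request body into an (HTTP status, JSON response bytes) pair."""
    try:
        text = json.loads(body)["text"]
    except (ValueError, KeyError, TypeError):
        error = {"error": "expected a JSON object with a 'text' field"}
        return 400, json.dumps(error).encode()
    return 200, json.dumps({"word": predict_fn(text)}).encode()


class PredictHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Read exactly the number of bytes the client announced.
        length = int(self.headers.get("Content-Length", 0))
        status, payload = handle_predict(self.rfile.read(length))
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(payload)


def serve(port=8000):
    """Start the server; blocks until interrupted."""
    HTTPServer(("0.0.0.0", port), PredictHandler).serve_forever()
```

After calling serve(), the Flutter app would POST a body such as {"text": "how are"} to the server and read the "word" field of the JSON response. The model stays on the server in this setup, so the app needs a network connection for every prediction.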

Upvotes: 2
