Mehdi Karbalai

Reputation: 11

How to read output of Keras Handwriting Recognition Model in TF Lite Android?

I'm trying to implement handwriting text recognition in my Android app. I found TensorFlow to be a doable solution, so I tried to create a .tflite model from the Handwriting Recognition model from Keras; the tutorial states that it is fully compatible with TF Lite. I managed to create the .tflite model and then, in Android, initialize the Interpreter with it. I then ran the Interpreter with a ByteBuffer of a bitmap, and the output has shape [1,32,81], i.e. an array of floats. As far as I know, the output should just be a String: the predicted text for the given input. How can I decode the output into the String I need?

I had a few problems:

  1. Converting the model to a .tflite file was tricky, but I managed to do it using certain flags as follows (a quick sanity check of the converted file is sketched after this list):
import tensorflow as tf

# prediction_model is the model built in the Keras handwriting recognition tutorial
converter = tf.lite.TFLiteConverter.from_keras_model(prediction_model)
# Allow TF ops that have no TF Lite builtin equivalent (Flex/Select ops)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
converter._experimental_lower_tensor_list_ops = False
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tf_lite_model = converter.convert()
open('textRecognitionModel.tflite', 'wb').write(tf_lite_model)
  2. According to the TF Lite docs, you have to use the following dependencies:
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'
// This dependency adds the necessary TF op support.
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:0.0.0-nightly-SNAPSHOT'
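
Before moving the file to Android, the converted model can be sanity-checked with the Python TF Lite interpreter to confirm its input and output tensors (a minimal sketch, assuming the textRecognitionModel.tflite written above and the full tensorflow pip package, which bundles the Flex delegate needed for the SELECT_TF_OPS):

import numpy as np
import tensorflow as tf

# Load the converted model and allocate its tensors
interpreter = tf.lite.Interpreter(model_path="textRecognitionModel.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
print("input:", input_details["shape"], input_details["dtype"])
print("output:", output_details["shape"], output_details["dtype"])

# Run one dummy inference; the output shape should come back as [1, 32, 81]
dummy = np.zeros(input_details["shape"], dtype=input_details["dtype"])
interpreter.set_tensor(input_details["index"], dummy)
interpreter.invoke()
print(interpreter.get_tensor(output_details["index"]).shape)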

After finally creating the .tflite model file, I added it to the assets directory of my Android app and tried loading it. However, it would crash with no error message, apparently a memory failure. I updated the libraries to the latest version:

"org.tensorflow:tensorflow-lite:2.11.0"
"org.tensorflow:tensorflow-lite-select-tf-ops:2.11.0"

And I converted my model file to a ByteBuffer as follows (I'm not sure if I'm handling the native byte order correctly):

// filename is the name of the model file in the assets dir
val inputStream = assetManager.open(filename)
val output = ByteArrayOutputStream()
inputStream.copyTo(output, 1024)
val file = output.toByteArray()
// The Interpreter expects a direct ByteBuffer in native byte order
val bb = ByteBuffer.allocateDirect(file.size)
bb.order(ByteOrder.nativeOrder())
bb.put(file)
// Reset the position so the buffer can be read from the start
bb.rewind()
return bb

With that, the initialization of the Interpreter API is finally working. I then run the interpreter on a ByteBuffer of a Bitmap, expecting the model to read the input and give the predicted text (a String) as output. However, the output has shape [1,32,81], so I created an array to hold the output and ran the Interpreter:

val output = Array(1) {
    Array(32) {
        FloatArray(81)
    }
}
// byteBuffer: ByteBuffer of bitmap
interpreter.run(byteBuffer, output)

And the output is an array of floats, and I don't understand what it means. Shouldn't it just be a String? I've attached a screenshot of the output array.

Can someone please help me?? I would highly appreciate any tips or solutions :)

Upvotes: 0

Views: 338

Answers (1)

MSS

Reputation: 3633

Before converting prediction_model to the tflite format, you need to add a custom decoding layer at the end and then convert that extended model to tflite.

import numpy as np
import tensorflow as tf
from tensorflow import keras

# num_to_char and max_length are defined in the handwriting_recognition notebook
prediction_model = keras.models.Model(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)  # This line is present in the handwriting_recognition notebook.

def CTCDecoder():
  def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_length]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text

  return tf.keras.layers.Lambda(decode_batch_predictions, name='decode')

decoded_pred_model = keras.models.Model(prediction_model.input, outputs=CTCDecoder()(prediction_model.output))

Now you can convert decoded_pred_model to the tflite format and use it. CTCDecoder is the custom layer added on top of prediction_model.output that decodes the predictions of shape [1,32,81] into text.
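
The conversion step itself would then look roughly like the converter snippet from the question, just pointed at decoded_pred_model instead of prediction_model (a minimal sketch, assuming the same SELECT_TF_OPS flags still work once the decoding ops are part of the graph; the output file name below is arbitrary, and it is worth verifying the resulting file in Python before shipping it to the app):

import tensorflow as tf

# Convert the decoder-augmented model instead of the raw prediction_model
converter = tf.lite.TFLiteConverter.from_keras_model(decoded_pred_model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # built-in TF Lite ops
    tf.lite.OpsSet.SELECT_TF_OPS,    # TF ops used by ctc_decode / tf.strings
]
converter._experimental_lower_tensor_list_ops = False
tf_lite_model = converter.convert()
# Write the decoded model out under a new (arbitrary) name
open('textRecognitionDecodedModel.tflite', 'wb').write(tf_lite_model)

If this works, the model's output on the Android side should be the decoded text rather than the raw [1,32,81] float tensor.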

Upvotes: 0
