NoLand'sMan

Building a TFLite model for multi-class classification

I have read several codelabs in which Google classifies images as belonging to one class. What if I need to use two or more classes? For example, I want to classify whether an image contains a fruit or a vegetable, and then classify which type of fruit or vegetable it is.

Answers (1)

Shubham Panchal

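You can train a convolutional neural network (CNN) with TensorFlow (specifically, with Keras). There are many examples online; see here and here. For multi-class classification, give the last layer one unit per class with a softmax activation, so the model outputs a probability for each class.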
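A minimal Keras sketch of such a model, assuming 28x28 grayscale input and three classes so that it lines up with modelInputDim and outputDim in the Android code of this answer. NUM_CLASSES, INPUT_DIM and build_classifier are hypothetical names, and the layer sizes are illustrative rather than tuned. The key parts for multi-class output are the final Dense layer with one unit per class and softmax, and sparse categorical cross-entropy for integer labels.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 3  # assumption: must equal outputDim in the Android code
INPUT_DIM = 28   # assumption: must equal modelInputDim in the Android code


def build_classifier(num_classes: int = NUM_CLASSES, input_dim: int = INPUT_DIM) -> tf.keras.Model:
    """Small CNN for single-channel images; layer sizes are illustrative, not tuned."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(input_dim, input_dim, 1)),  # grayscale input, NHWC
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation="relu"),
        # One unit per class; softmax makes the outputs sum to 1.
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    # Integer labels 0..num_classes-1 pair with sparse categorical cross-entropy.
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


if __name__ == "__main__":
    # Random placeholder data only to show the expected shapes; use real images and labels.
    x_train = np.random.rand(16, INPUT_DIM, INPUT_DIM, 1).astype("float32")
    y_train = np.random.randint(0, NUM_CLASSES, size=(16,))
    model = build_classifier()
    model.fit(x_train, y_train, epochs=1, verbose=0)
    model.save("keras_model.h5")  # the file name loaded in the conversion step
```

With real data, each label is the integer index of its class (for example 0, 1, 2), and that same order has to be used when mapping the model's outputs back to names in the app.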

Next, convert the saved Keras model (.h5 file) to a .tflite file using tf.lite.TFLiteConverter:

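import tensorflow as tf

# TF 2.x: load the saved model, then build the converter from it.
# (TF 1.x used tf.lite.TFLiteConverter.from_keras_model_file; in TF 2.x it is under tf.compat.v1.lite.)
model = tf.keras.models.load_model("keras_model.h5")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("converted_model.tflite", "wb") as f:  # closes the file reliably
    f.write(tflite_model)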

See here.
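Before bundling the file into the app, it can help to check on the desktop that the converted model's input and output shapes match modelInputDim and outputDim, because a mismatch typically shows up as an exception or wrong results at interpreter.run. A sketch using tf.lite.Interpreter, assuming a 28x28 single-channel input and three classes; inspect_tflite_model and tiny_tflite_model are hypothetical helper names, and the tiny model exists only so the sketch runs without a trained file.

```python
import numpy as np
import tensorflow as tf


def inspect_tflite_model(tflite_model: bytes):
    """Run one all-zeros input through a TFLite model and report its I/O shapes."""
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    # Build a dummy input with exactly the shape and dtype the model expects.
    dummy = np.zeros(input_details["shape"], dtype=input_details["dtype"])
    interpreter.set_tensor(input_details["index"], dummy)
    interpreter.invoke()
    scores = interpreter.get_tensor(output_details["index"])
    in_shape = tuple(int(d) for d in input_details["shape"])
    out_shape = tuple(int(d) for d in output_details["shape"])
    return in_shape, out_shape, scores


def tiny_tflite_model(input_dim: int = 28, num_classes: int = 3) -> bytes:
    """Demo only: convert a throwaway (input_dim x input_dim x 1) -> num_classes model in memory."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(input_dim, input_dim, 1)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    return tf.lite.TFLiteConverter.from_keras_model(model).convert()


if __name__ == "__main__":
    # For the real file: with open("converted_model.tflite", "rb") as f: inspect_tflite_model(f.read())
    in_shape, out_shape, scores = inspect_tflite_model(tiny_tflite_model())
    print(in_shape, out_shape)
```

The reported input shape should read (1, modelInputDim, modelInputDim, 1) and the output shape (1, outputDim); those are the array dimensions the Java code below allocates.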

Now, in Android, take a Bitmap image and convert it to a float[][][][] in the [batch][height][width][channels] layout of the model's input tensor:

import android.graphics.Bitmap;
import android.graphics.Color;

// Converts an already-resized Bitmap to a [1][rows][cols][1] grayscale input in [0, 1].
private float[][][][] convertImageToFloatArray(Bitmap image) {
    float[][][][] imageArray = new float[1][modelInputDim][modelInputDim][1];
    for (int y = 0; y < modelInputDim; y++) {        // row
        for (int x = 0; x < modelInputDim; x++) {    // column
            int pixel = image.getPixel(x, y);        // read each pixel once
            float r = Color.red(pixel);
            float g = Color.green(pixel);
            float b = Color.blue(pixel);
            // Weighted luminance, scaled from [0, 255] to [0, 1].
            float gray = (0.3f * r + 0.59f * g + 0.11f * b) / 255f;
            // NHWC layout: index by row (y) first, then column (x), so the image is not transposed.
            imageArray[0][y][x][0] = gray;
        }
    }
    return imageArray;
}

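Here, modelInputDim is the width and height of the model's (square) input image, and the bitmap must already be resized to that size. The snippet converts the RGB image to grayscale using the weights 0.3, 0.59 and 0.11 and scales the values to [0, 1]. This only matches the model if it was trained on single-channel images normalized the same way.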
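Because the app and the training script must prepare images identically, one option is to reproduce this preprocessing in Python and run the training images through it. A NumPy sketch that mirrors convertImageToFloatArray, assuming (height, width, 3) RGB arrays with values 0..255; rgb_to_model_input is a hypothetical name, not part of TensorFlow or TFLite.

```python
import numpy as np


def rgb_to_model_input(rgb: np.ndarray) -> np.ndarray:
    """Python counterpart of convertImageToFloatArray (a sketch, not a library API).

    rgb: array of shape (height, width, 3) with values in 0..255.
    Returns a float32 array of shape (1, height, width, 1) with values in [0, 1].
    """
    if rgb.ndim != 3 or rgb.shape[2] != 3:
        raise ValueError(f"expected (height, width, 3), got {rgb.shape}")
    rgb = rgb.astype(np.float32)
    # Same luminance weights as the Java code, then scale 0..255 down to 0..1.
    gray = (0.3 * rgb[..., 0] + 0.59 * rgb[..., 1] + 0.11 * rgb[..., 2]) / 255.0
    # Add batch and channel axes; axis 1 is the row (y), axis 2 the column (x).
    return gray[np.newaxis, :, :, np.newaxis].astype(np.float32)
```

For a list of training images this could be applied as np.concatenate([rgb_to_model_input(img) for img in images]), which yields the (N, height, width, 1) batch a grayscale Keras model expects.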

Finally, run inference:

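private final int modelInputDim = 28;   // input width/height the model was trained on
private final int outputDim = 3;        // number of classes = units in the final softmax layer

// interpreter is an org.tensorflow.lite.Interpreter loaded with converted_model.tflite.
// getCroppedBitmap() and resizeBitmap() are helpers not shown here; resizeBitmap()
// must scale the crop to modelInputDim x modelInputDim.
private float[] performInference(Bitmap frame, RectF cropImageRectF) {
    Bitmap croppedBitmap = getCroppedBitmap(frame, cropImageRectF);
    Bitmap croppedFrame = resizeBitmap(croppedBitmap);
    float[][][][] imageArray = convertImageToFloatArray(croppedFrame);
    float[][] outputArray = new float[1][outputDim];   // [batch][class]
    interpreter.run(imageArray, outputArray);
    // One score per class (probabilities with a softmax output); the largest is the prediction.
    return outputArray[0];
}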
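For the fruit/vegetable case in the question, one option is a single model whose classes are the specific types, plus a lookup from each type to its category. The sketch below shows that post-processing in Python so it can be checked off-device; the same argmax and lookup translate directly to the Java float[]. LABELS, CATEGORY, decode_prediction and category_scores are hypothetical names, and the three labels are placeholders matching outputDim = 3.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical labels; the order must match the integer class indices used in training.
LABELS: List[str] = ["apple", "banana", "carrot"]
# Hypothetical mapping from each specific label to its coarse category.
CATEGORY: Dict[str, str] = {"apple": "fruit", "banana": "fruit", "carrot": "vegetable"}


def decode_prediction(scores: Sequence[float]) -> Tuple[str, str, float]:
    """Return (category, label, score) for the highest-scoring class."""
    if len(scores) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} scores, got {len(scores)}")
    best = max(range(len(scores)), key=lambda i: scores[i])  # argmax
    label = LABELS[best]
    return CATEGORY[label], label, scores[best]


def category_scores(scores: Sequence[float]) -> Dict[str, float]:
    """Sum per-class probabilities into per-category probabilities (assumes softmax output)."""
    totals: Dict[str, float] = {}
    for label, score in zip(LABELS, scores):
        category = CATEGORY[label]
        totals[category] = totals.get(category, 0.0) + score
    return totals


if __name__ == "__main__":
    output = [0.1, 0.7, 0.2]  # e.g. the float[] returned by performInference
    print(decode_prediction(output))  # ('fruit', 'banana', 0.7)
    print(category_scores(output))
```

Summing the softmax probabilities of all types in a category gives a category probability from the same single model. An alternative is two models, a fruit/vegetable classifier followed by a per-category type classifier, at the cost of running two interpreters.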

I have prepared a collection of Android apps that use TFLite models; see here.
