Reputation: 781
I have re-trained a mobilenet-v1 image classification model from Tensorflow Hub, and converted it using toco for inference using Tensorflow Lite.
However, when I run inference using the tflite model, it requires a different input size than what I specified with --input_shape.
How can I re-train a mobilenetv1 quantized model on my own data?
Here are the steps that I attempted:
Retrain the mobilenet v1 quantized model on TF Hub using the dataset above
python retrain.py \
--bottleneck_dir="${IMAGE_DIR}"/tf_files/bottlenecks/ \
--how_many_training_steps=1000 \
--model_dir="${IMAGE_DIR}"/tf_files/models/mobilenet_v1_050_224 \
--summaries_dir="${IMAGE_DIR}"/tf_files/training_summaries/mobilenet_v1_050_224/ \
--output_graph="${IMAGE_DIR}"/tf_files/retrained_mobilenet_v1_050_224.pb \
--output_labels="${IMAGE_DIR}"/tf_files/retrained_labels.txt \
--tfhub_module=https://tfhub.dev/google/imagenet/mobilenet_v1_050_224/quantops/classification/1 \
--image_dir="${IMAGE_DIR}"/tf_files/flower_photos
Verify that the model is properly trained and that the input/output tensor names are correct
python label_image.py \
--graph="${IMAGE_DIR}"/tf_files/retrained_mobilenet_v1_050_224.pb \
--labels="${IMAGE_DIR}"/tf_files/retrained_labels.txt \
--input_layer=Placeholder \
--output_layer=final_result \
--input_height=224 --input_width=224 \
--image="${IMAGE_DIR}"/tf_files/flower_photos/daisy/21652746_cc379e0eea_m.jpg
Convert the model to tflite
toco \
--input_file="${IMAGE_DIR}"/tf_files/retrained_mobilenet_v1_050_224.pb \
--output_file="${IMAGE_DIR}"/tf_files/mobilenet_v1_050_224_quant.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--input_shape=1,224,224,3 \
--input_array=Placeholder \
--output_array=final_result \
--inference_type=QUANTIZED_UINT8 \
--input_data_type=FLOAT
Although I specified --input_shape=1,224,224,3, when I run inference I get this error:
java.lang.IllegalArgumentException: DataType (1) of input data does not match with the DataType (3) of model inputs.
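My reading of the error (an assumption: I believe the TFLite Java bindings' DataType enum numbers FLOAT32 as 1 and UINT8 as 3, but I have not verified this against the source) is that a float32 input buffer is being fed to a model that expects uint8. A minimal decode of the two codes:

```python
# Hypothetical decode of the DataType codes in the Java exception.
# Assumption (unverified): in the TFLite Java bindings,
# FLOAT32 = 1, INT32 = 2, UINT8 = 3, INT64 = 4.
TFLITE_JAVA_DTYPES = {1: "FLOAT32", 2: "INT32", 3: "UINT8", 4: "INT64"}

def decode_mismatch(input_code, model_code):
    """Render the mismatch reported by the IllegalArgumentException."""
    return (f"input buffer is {TFLITE_JAVA_DTYPES[input_code]}, "
            f"model expects {TFLITE_JAVA_DTYPES[model_code]}")

print(decode_mismatch(1, 3))  # input buffer is FLOAT32, model expects UINT8
```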
Upvotes: 0
Views: 3130
Reputation: 781
Fast forward to 2020: the easiest way to train a TF Lite image classification model is now TF Lite Model Maker. https://www.tensorflow.org/lite/tutorials/model_maker_image_classification
The resulting TF Lite model can be dragged and dropped into Android Studio with the ML Model Binding plugin. See the end-to-end flow in this video: https://www.youtube.com/watch?v=s_XOVkjXQbU
Upvotes: 0
Reputation: 2157
Maybe I am wrong, but the datatype error does not seem to be a problem with the shape of the input data; it looks like a problem with the datatype itself.
If you quantize a model, you change the datatype from float32 to int8 (or uint8).
Depending on how you run your model, there are different types of quantization.
There are also other quantization methods, but what I am aiming at is this: if you did a full quantization, either your model includes a quantization layer that converts float32 to int8 for you, or your model expects int8 inputs directly.
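To illustrate the datatype point with a generic affine-quantization sketch (the scale and zero_point values below are made up for illustration, not taken from any actual model):

```python
import numpy as np

# Generic affine quantization: real ≈ scale * (q - zero_point).
# scale and zero_point are illustrative values, not from a real model.
scale, zero_point = 0.02, 128

def quantize(x):
    # float32 -> uint8: the step a fully quantized model expects done already
    q = np.round(x / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

def dequantize(q):
    # uint8 -> float32: what a built-in quantization layer would handle for you
    return scale * (q.astype(np.float32) - zero_point)

x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
q = quantize(x)
print(q.dtype)        # uint8 -- this is what the quantized model wants fed in
print(dequantize(q))
```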
Edit: I just saw that you set the input type to FLOAT; float32 is probably the more precise term. Either way, something is off between your input datatype and the datatype your first layer expects.
You can use a tool like Netron to look at your input layer and see what is expected. The tool can also be used to identify how your network was quantized.
Good luck and stay safe.
Upvotes: 0
Reputation: 1685
--input_data_type="" or --input_data_types="" >> Input array type, if not already provided in the graph. Typically needs to be specified when passing arbitrary arrays to --input_arrays.
In my case it was not needed (I used a pretrained MobileNet_V2 model).
You have to add a few more arguments (--mean_value, --std_value, --default_ranges_min and --default_ranges_max) to your command for quantization.
As mentioned in the GitHub documentation page, the following command works for me:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=~/tf_files/retrained_graph_mobileNet_v2_100_224.pb \
--output_file=~/tf_files/retrained_graph_mobileNet_q_v2_100_224.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--inference_type=QUANTIZED_UINT8 \
--input_shape=1,224,224,3 \
--input_array=Placeholder \
--output_array=final_result \
--mean_value=128 \
--std_value=128 \
--default_ranges_min=0 \
--default_ranges_max=6
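My understanding of the mean_value/std_value flags (worth double-checking against the converter docs) is that they tell TOCO how quantized uint8 bytes map to real input values, roughly real = (q - mean_value) / std_value. With 128/128, the uint8 range [0, 255] covers about [-1.0, 0.992]:

```python
import numpy as np

# My reading of TOCO's input quantization parameters (verify against the docs):
#   real_value = (quantized_uint8 - mean_value) / std_value
mean_value, std_value = 128.0, 128.0

q = np.array([0, 128, 255], dtype=np.uint8)
real = (q.astype(np.float32) - mean_value) / std_value
print(real)  # approximately [-1.0, 0.0, 0.9921875]
```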
Upvotes: 0