Reputation: 2478
I have a pruned and clustered VGG-18 CNN model for the CIFAR-10 dataset, written in Python 3 and TensorFlow 2. The code I found in tutorials for computing its accuracy feeds the inputs one at a time, so evaluating all 10,000 validation images takes ages. I thought about feeding the validation images in batches and coded the following:
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size = 10000).batch(batch_size)
representative_dataset = tf.data.Dataset.from_tensor_slices(X_test.astype('float32'))
representative_dataset = representative_dataset.shuffle(buffer_size = 10000).batch(batch_size = batch_size)
converter = tf.lite.TFLiteConverter.from_keras_model(clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# batch_size = 64
def representative_dataset_gen():
    # for i, samples in enumerate(representative_dataset.take(1)):
    for i, samples in enumerate(representative_dataset.take(batch_size)):
        yield [samples]
converter.representative_dataset = representative_dataset_gen
# Restrict supported target op specification to INT8-
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_model_int = converter.convert()
with tf.io.gfile.GFile("VGG18_Pruned_Clustered_Trained_Quantized.tflite", 'wb') as file:
file.write(tflite_model_int)
# Load TF Lite file and allocate input & output tensors-
tflite_model_file = 'VGG18_Pruned_Clustered_Trained_Quantized.tflite'
interpreter = tf.lite.Interpreter(model_path = tflite_model_file)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.resize_tensor_input(input_details[0]['index'], (batch_size, 32, 32, 3))
interpreter.resize_tensor_input(output_details[0]['index'], (batch_size, num_classes))
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
# Prepare validation dataset while generating only one sample at a time-
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_dataset = test_dataset.batch(batch_size = batch_size)
# Make predictions using Pruned, Trained and Quantized VGG-18 TF Lite model-
predictions = []
test_labels, test_imgs = [], []
data_sample_count = 0
# for img, label in tqdm(test_batches.take(100)):
for img, label in test_dataset.take(10000):
    interpreter.set_tensor(input_index, img)
    interpreter.invoke()
    predictions.append(interpreter.get_tensor(output_index))
    test_labels.append(label.numpy()[0])
    test_imgs.append(img)
    data_sample_count += 1
    print(data_sample_count)
This prints a running count of 156 batches of size 64. I have the following two questions:
1.) The last batch contains only 16 validation images. How can I handle this? For this batch of 16 images, the line:
interpreter.set_tensor(input_index, img)
gives the error:
ValueError: Cannot set tensor: Dimension mismatch. Got 16 but expected 64 for dimension 0 of input 50.
2.) The predictions made by the 'interpreter', which are stored in the list 'predictions', have the wrong dimensions because:
test_imgs[0].shape
# TensorShape([64, 32, 32, 3])
test_labels[0]
# array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=float32)
This means that each element of the 'test_imgs' list contains a batch of 64 images, whereas each element of 'test_labels' contains only a single label instead of the corresponding 64 labels.
How can I fix these errors?
Thanks
Upvotes: 1
Views: 2152
Reputation: 903
To make the graph flexible with respect to the input size, the TensorFlow graph should be designed that way; for example, the batch dimension in the graph should be None instead of 64. Then, when using the converted TFLite model for inference, the interpreter.resize_tensor_input method should be invoked to update the shape information with the new batch size before setting the tensor data.
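A minimal sketch of that flow, reusing the names from the question (tflite_model_file, test_dataset) and assuming the Keras model was built with a dynamic batch dimension and that the images already match the model's float32 input dtype:

import tensorflow as tf

# Build the Keras model with a dynamic batch dimension before conversion, e.g.
# inputs = tf.keras.Input(shape = (32, 32, 3))   # batch dimension is None, not 64

interpreter = tf.lite.Interpreter(model_path = tflite_model_file)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

for img, label in test_dataset:
    # Resize the input tensor to this batch's actual size (64, or 16 for the
    # final partial batch), then re-allocate before copying the data in.
    interpreter.resize_tensor_input(input_details[0]['index'], (img.shape[0], 32, 32, 3))
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'], img)
    interpreter.invoke()
    batch_predictions = interpreter.get_tensor(output_details[0]['index'])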
To get the test labels for all the batches, the outputs have to be collected for every batch as well. Please review your evaluation code and make sure it keeps the labels for all the batches to meet your needs.
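For example (a sketch assuming the labels come from the tf.data pipeline in the question), the per-batch predictions and labels can be accumulated in full and flattened afterwards, so each batch contributes all 64 (or 16) one-hot labels instead of only the first one:

import numpy as np

predictions, test_labels = [], []
for img, label in test_dataset:
    # Run the interpreter on this batch exactly as in the sketch above
    interpreter.resize_tensor_input(input_details[0]['index'], (img.shape[0], 32, 32, 3))
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'], img)
    interpreter.invoke()
    predictions.append(interpreter.get_tensor(output_details[0]['index']))
    test_labels.append(label.numpy())  # keep every label in the batch, not label.numpy()[0]

# Flatten the per-batch lists into (10000, num_classes) arrays and compute accuracy
predictions = np.concatenate(predictions, axis = 0)
test_labels = np.concatenate(test_labels, axis = 0)
accuracy = np.mean(np.argmax(predictions, axis = 1) == np.argmax(test_labels, axis = 1))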
Upvotes: 3