Reputation: 15
I have followed this TensorFlow tutorial to classify images using a transfer learning approach. After training on almost 16,000 manually classified images (with roughly a 40/60 split between classes 1 and 0) on top of the pre-trained MobileNet V2 base model, my model achieved 96% accuracy on the held-out test set. I then saved the resulting model.
Next, I would like to use this trained model to classify new images. To do so, I adapted a portion of the tutorial's code (the part at the end, under #Retrieve a batch of images from the test set) as shown below. The code works; however, it only processes a single batch of 32 images and then stops, even though there are hundreds of images in the source folder. What am I missing here? Please advise.
# Import libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import preprocessing
from tensorflow.keras.preprocessing import image_dataset_from_directory
import matplotlib.pyplot as plt
import numpy as np
import os
# Load saved model
model = tf.keras.models.load_model('/model')
# Re-compile model (optional: predict() does not need a compiled model)
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
# Define paths
PATH = 'Data/'
new_dir = os.path.join(PATH, 'New_images') # New_images must contain at least one class (sub-folder)
IMG_SIZE = (640, 640)
BATCH_SIZE = 32
new_dataset = image_dataset_from_directory(new_dir, shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE)
# Retrieve a batch of images from the test set
image_batch, label_batch = new_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()
# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)
print('Predictions:\n', predictions.numpy())
len(new_dataset) # equals 25, i.e., there are 25 batches
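The snippet above stops after 32 images because as_numpy_iterator().next() returns only the first element of the dataset, and each element of a batched dataset is one whole batch. The sketch below iterates over every batch instead and collects predictions and labels in the same pass, so they stay aligned even though a shuffled dataset is reshuffled on each pass. It assumes a hypothetical stand-in model and random data in place of the saved model and the image folder.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the saved model: returns one logit per image.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(64, 64, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1),  # no activation, i.e. logits
])

# Hypothetical data: 100 random "images" with 0/1 labels, shuffled and batched
# similarly to image_dataset_from_directory(shuffle=True, batch_size=32).
images = np.random.rand(100, 64, 64, 3).astype("float32")
labels = np.random.randint(0, 2, size=100).astype("int32")
new_dataset = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .shuffle(100)
    .batch(32)
)

all_predictions, all_labels = [], []
# Looping over the iterator yields every batch, not just the first one.
for image_batch, label_batch in new_dataset.as_numpy_iterator():
    logits = model.predict_on_batch(image_batch).flatten()
    probabilities = tf.nn.sigmoid(logits).numpy()  # logits -> probabilities
    all_predictions.append(np.where(probabilities < 0.5, 0, 1))
    all_labels.append(label_batch)

predictions = np.concatenate(all_predictions)
true_labels = np.concatenate(all_labels)
print(predictions.shape, true_labels.shape)  # (100,) (100,)
```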
Reputation: 379
Replace this code:
# Retrieve a batch of images from the test set
image_batch, label_batch = new_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()
with this one:
predictions = model.predict(new_dataset).flatten()  # the dataset is already batched, so no batch_size argument
tf.data.Dataset objects can be passed directly to the predict() method, which runs over every batch in the dataset rather than only the first one. Reference
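A minimal sketch contrasting the two approaches, assuming a hypothetical stand-in model and random in-memory data instead of the saved model and image folder. The dataset here is not shuffled, so output i corresponds to image i; with the real folder, passing shuffle=False to image_dataset_from_directory should likewise keep a fixed (alphanumeric, per the Keras documentation) file order, whereas shuffle=True reorders the images on every pass.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the loaded model (one logit per image, no activation).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(64, 64, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1),
])

BATCH_SIZE = 32
images = np.random.rand(100, 64, 64, 3).astype("float32")  # hypothetical data
labels = np.zeros(100, dtype="int32")  # dummy labels; predict() ignores them
# No shuffle, so prediction i corresponds to images[i].
new_dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(BATCH_SIZE)

# First batch only: what .next() on the iterator gives you.
image_batch, _ = new_dataset.as_numpy_iterator().next()
first_batch_logits = model.predict_on_batch(image_batch).flatten()

# All batches: predict() iterates the whole dataset itself.
logits = model.predict(new_dataset, verbose=0).flatten()
predictions = tf.where(tf.nn.sigmoid(logits) < 0.5, 0, 1).numpy()

print(len(new_dataset), first_batch_logits.shape, predictions.shape)
# 4 (32,) (100,)
```

The labels in the (image, label) tuples are unused by predict(), so the same call works on a dataset built by image_dataset_from_directory with inferred labels.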