Reputation: 357
I am trying to predict several million images with my trained model using predict_generator in Python 3, with Keras and TensorFlow as the backend. The generator and the model predictions work; however, some images in the directory are broken or corrupted and cause predict_generator to stop and throw an error. Once the offending image is removed it works again, until the next corrupted/broken image gets fed through the function.
Since there are so many images, it is not feasible to run a script that opens every image and deletes the ones that throw an error. Is there a way to incorporate a "skip image if broken" argument into the generator or the flow_from_directory function?
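Roughly, my pipeline is the standard flow_from_directory plus predict_generator setup; a simplified sketch (the path, image size, and batch size here are placeholders, and model is my trained model):

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rescale=1. / 255)
generator = datagen.flow_from_directory(
    'path/to/images',        # directory containing the several million images
    target_size=(224, 224),
    batch_size=64,
    class_mode=None,         # prediction only, no labels
    shuffle=False)

# stops with an error as soon as a corrupted/broken file is read
predictions = model.predict_generator(generator, steps=len(generator), verbose=1)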
Any help is greatly appreciated!
Upvotes: 1
Views: 732
Reputation: 1141
Since it happens during prediction, if you skip any image or batch you need to keep track of which images were skipped, so that you can correctly map the prediction scores back to the image file names.
Based on this idea, my DataGenerator is implemented with a valid-image index tracker. In particular, pay attention to the variable valid_index, where the indices of valid images are tracked.
import numpy as np
import keras


class DataGenerator(keras.utils.Sequence):
    def __init__(self, df, batch_size, verbose=False, **kwargs):
        self.verbose = verbose
        self.df = df
        self.batch_size = batch_size
        self.valid_index = kwargs['valid_index']
        self.success_count = self.total_count = 0

    def __len__(self):
        return int(np.ceil(self.df.shape[0] / float(self.batch_size)))

    def __getitem__(self, idx):
        print('generator is loading batch ', idx)
        batch_df = self.df.iloc[idx * self.batch_size:(idx + 1) * self.batch_size]
        self.total_count += batch_df.shape[0]
        # returns a list whose elements are either an image array (valid image)
        # or None (corrupted image)
        x = load_batch_image_to_arrays(batch_df['image_file_names'])
        # filter out corrupted images
        tmp = [(u, i) for u, i in zip(x, batch_df.index.values.tolist())
               if u is not None]
        # boundary case: all images in this batch failed to load
        if len(tmp) == 0:
            print('[ERROR] All images loading failed')
            # based on https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py#L621,
            # Keras will automatically find the next batch if it returns None
            return None
        print('successfully loaded images in batch {}: {}/{}'.format(idx, len(tmp), self.batch_size))
        self.success_count += len(tmp)
        x, batch_index = zip(*tmp)
        x = np.stack(x)  # list to np.array
        self.valid_index[idx] = batch_index
        # follow the preprocess_input function provided by keras
        x = resnet50_preprocess(np.array(x, dtype=float))
        return x

    def on_epoch_end(self):
        print('total image count', self.total_count)
        print('successful image count', self.success_count)
        self.success_count = self.total_count = 0  # reset counts after each epoch
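load_batch_image_to_arrays and resnet50_preprocess are not shown above; a minimal sketch of what they could look like (assuming PIL for loading, keras' ResNet50 preprocess_input, and a fixed 224x224 input size, all of which are assumptions), together with how data_gen and the shared valid_index dict can be constructed:

import numpy as np
from PIL import Image
from keras.applications.resnet50 import preprocess_input as resnet50_preprocess


def load_batch_image_to_arrays(file_names):
    """Return one array per file name, or None where the file cannot be decoded."""
    arrays = []
    for fname in file_names:
        try:
            with Image.open(fname) as img:
                img = img.convert('RGB').resize((224, 224))  # assumed model input size
                arrays.append(np.asarray(img))
        except Exception:  # any decoding error -> treat the image as broken
            arrays.append(None)
    return arrays


# valid_index is a plain dict; prediction below runs with threads only
# (use_multiprocessing=False), so no special shared structure is needed.
# batch_size=64 is an arbitrary example value.
valid_index = {}
data_gen = DataGenerator(df, batch_size=64, valid_index=valid_index)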
During prediction:

predictions = model.predict_generator(
    generator=data_gen,
    workers=10,
    use_multiprocessing=False,
    max_queue_size=20,
    verbose=1
).squeeze()

indexes = []
for i in sorted(data_gen.valid_index.keys()):
    indexes.extend(data_gen.valid_index[i])
result_df = df.loc[indexes]
result_df['score'] = predictions
Upvotes: 0
Reputation: 2522
There's no such argument in ImageDataGenerator, nor in the flow_from_directory method, as you can see in the Keras docs for both (here and here). One workaround would be to extend the ImageDataGenerator class and overload the flow_from_directory method to check whether an image is corrupted or not before yielding it in the generator (sketched below). Here you can find its source code.
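A minimal sketch of that idea (hypothetical and untested; the class name is made up, and it drops whole batches that fail to decode rather than single images):

from keras.preprocessing.image import ImageDataGenerator


class SkipBrokenImageDataGenerator(ImageDataGenerator):
    """Skip batches containing unreadable files instead of raising."""

    def flow_from_directory(self, *args, **kwargs):
        iterator = super().flow_from_directory(*args, **kwargs)
        n_batches = len(iterator)

        def generate():
            for _ in range(n_batches):
                try:
                    # the iterator draws the batch indices before loading the
                    # files, so a batch that fails here is simply skipped and
                    # the next call moves on to the following batch
                    yield next(iterator)
                except (IOError, OSError, SyntaxError):
                    continue

        return generate()

Because the returned object is a plain generator rather than a Sequence, predict_generator needs an explicit steps argument, and every skipped batch silently shortens the output, which is exactly the bookkeeping problem the other answer addresses with valid_index.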
Upvotes: 1