Vahid the Great

Reputation: 483

Making prediction in Keras using a custom generator

I have an image-classifier model in TensorFlow that I want to make predictions with. I have created a custom generator to avoid loading all the images into RAM at the same time.

import requests
import tensorflow
from io import BytesIO
from PIL import Image

def load_and_preprocess_image(url_path_x):
    # Download the image and decode it from the response bytes
    with requests.Session() as s:
        request_x = s.get(url_path_x).content
    img = Image.open(BytesIO(request_x))
    img = img.convert('RGB')
    img = img.resize((224, 224), Image.NEAREST)
    img = tensorflow.keras.preprocessing.image.img_to_array(img)
    return img

def prediction_generator(urls_x):
    # Yield one image at a time as a batch of shape (1, 224, 224, 3)
    for url_x in urls_x:
        try:
            yield load_and_preprocess_image(url_x).reshape(1, 224, 224, 3)
        except Exception:
            # Fall back to a placeholder image (dummy_image_path is defined elsewhere)
            yield load_and_preprocess_image(dummy_image_path).reshape(1, 224, 224, 3)

my_path_gen = prediction_generator(df['url_path_column'])
preds_probas = model_i.predict(my_path_gen, batch_size=1, verbose=0, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False)

However, my code seems to consume excessive RAM, as if it were loading all the images into memory at the same time. Is there anything wrong with my custom generator?

Upvotes: 1

Views: 234

Answers (2)

Vahid the Great

Reputation: 483

The generators are fine, guys; they aren't using excessive RAM. The issue was somewhere else.

Anyway, I'm leaving the question here in case the code is useful to someone.
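
For anyone who wants to check this kind of thing themselves, here is a minimal sketch using the standard-library tracemalloc module. It is not my original code: `fake_images`, `peak_bytes`, `stream` and `materialize` are hypothetical names, and the "images" are just 1 MB bytearrays standing in for real downloads. It compares the peak allocation when items are consumed one at a time against building the full list first.

```python
import tracemalloc


def fake_images(n, size=1_000_000):
    # Hypothetical stand-in for downloaded images: each item is ~1 MB.
    for _ in range(n):
        yield bytearray(size)


def peak_bytes(consume):
    # Return the peak traced allocation (in bytes) while running consume().
    tracemalloc.start()
    consume()
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return peak


def stream():
    # Consume one item at a time; earlier items can be freed as we go.
    for img in fake_images(20):
        len(img)


def materialize():
    # Build the whole list first, which keeps every item alive at once.
    return len(list(fake_images(20)))


peak_stream = peak_bytes(stream)
peak_list = peak_bytes(materialize)
print(f"streaming peak: {peak_stream / 1e6:.1f} MB")
print(f"list peak:      {peak_list / 1e6:.1f} MB")
```

If the streaming peak stays close to the size of a couple of items while the list peak grows with the number of items, the generator itself is not the part holding on to memory. Keep in mind that Keras may prefetch several items from a generator, so real usage can sit somewhat above a single image.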

Upvotes: 1

Vahid

Reputation: 1417

I'm definitely not an expert on this topic, but shouldn't the generator implement __len__ and __getitem__? From this link:

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(list_IDs_temp)

        return X, y
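
One thing to watch if you adapt this for prediction: the floor in `__len__` drops a final partial batch, which is usually fine for training but would leave some samples without a prediction. Below is a rough sketch of the same index arithmetic with ceil instead. The class name `UrlBatches` and the `img_*.jpg` strings are made up for illustration, and it deliberately has no Keras base class so it runs on its own; as far as I understand, a real input to `predict` would subclass `tf.keras.utils.Sequence` and return image arrays rather than strings.

```python
import math


class UrlBatches:
    """Index arithmetic only (hypothetical class, no Keras dependency)."""

    def __init__(self, urls, batch_size):
        self.urls = list(urls)
        self.batch_size = batch_size

    def __len__(self):
        # ceil instead of floor, so the final partial batch is kept
        return math.ceil(len(self.urls) / self.batch_size)

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        start = index * self.batch_size
        # Slicing past the end is safe; the last batch is simply shorter
        return self.urls[start:start + self.batch_size]


batches = UrlBatches([f"img_{i}.jpg" for i in range(10)], batch_size=4)
print(len(batches))  # 3
print(batches[2])    # ['img_8.jpg', 'img_9.jpg']
```

With floor, the same 10 items and a batch size of 4 would give only 2 batches, and the last two items would never be seen.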

Upvotes: 1
