I'm building a generator for Keras so I can load my dataset images in batches, since the dataset is too big to fit in my RAM.
I built the generator like this:
```python
# import the necessary packages
import cv2
import tensorflow
from tensorflow import keras
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
import numpy as np
import pandas as pd
from tqdm import tqdm

# loading
path_to_txt = "/content/test/leafsnap-dataset/leafsnap-dataset-images_improved.txt"
df = pd.read_csv(path_to_txt, sep='\t')
arr = np.array(df)

# epochs and steps:
NUM_TRAIN_IMAGES = 0
NUM_EPOCHS = 30

def image_generator(arr, bs, mode="train", aug=None):
    while True:
        images = []
        labels = []
        for row in arr:
            if len(images) < bs:
                img = cv2.resize(cv2.imread("/content/test/leafsnap-dataset/" + row[0]), (224, 224))
                images.append(img)
                labels.append([row[2]])
                NUM_TRAIN_IMAGES += 1
            else:
                break
        if aug is not None:
            (images, labels) = next(aug.flow(np.array(images), labels, batch_size=bs))
        obj = OneHotEncoder()
        values = obj.fit_transform(labels).toarray()
        yield (np.array(images), labels)
```
I then call `fit_generator` on a Sequential model (the CNN itself worked until I ran into an OOM error):
```python
# create the augmentation function:
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
                         width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
                         horizontal_flip=True, fill_mode="nearest")

# create the generator:
gen = image_generator(arr, bs=32, mode="train", aug=aug)
history = model.fit_generator(image_generator,
                              steps_per_epoch=NUM_TRAIN_IMAGES,
                              epochs=NUM_EPOCHS)
```
This raises the following error:
```
    # Create generator from NumPy or EagerTensor Input.
--> 377 num_samples = int(nest.flatten(data)[0].shape[0])
    378 if batch_size is None:
    379   raise ValueError('You must specify `batch_size`')

AttributeError: 'function' object has no attribute 'shape'
```
I see two major errors here.
First, your generator function is not memory-efficient, because it loads the images up front inside the `while` loop. Instead, iterate over the image files and, inside that loop, `yield` a NumPy array of images together with their labels.
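A minimal sketch of that idea, assuming the image paths and class names have already been pulled out of the dataframe into two plain lists. The names `batch_generator`, `paths`, `labels` and `load_image` are hypothetical, and the image loading is passed in as a function so the batching logic does not depend on OpenCV:

```python
import numpy as np


def batch_generator(paths, labels, batch_size, load_image, shuffle=True, seed=0):
    """Yield (images, one_hot_labels) batches forever, reading one batch of files at a time.

    paths:      list of image file paths (hypothetical input layout)
    labels:     list of class names, same length and order as paths
    load_image: callable that turns one path into an HxWxC array; in the question's
                setup this would wrap cv2.imread + cv2.resize (assumption)
    """
    # Build the class-to-column mapping once, over the whole dataset, so every
    # batch uses the same one-hot columns.
    classes = sorted(set(labels))
    class_index = {name: i for i, name in enumerate(classes)}
    y_all = np.array([class_index[name] for name in labels])
    one_hot = np.eye(len(classes), dtype=np.float32)

    order = np.arange(len(paths))
    rng = np.random.default_rng(seed)
    while True:  # Keras expects the generator to keep yielding across epochs
        if shuffle:
            rng.shuffle(order)  # new sample order on each full pass
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            # Only the images of this batch are loaded into memory.
            images = np.stack([load_image(paths[i]) for i in idx])
            yield images, one_hot[y_all[idx]]
```

With the question's data, `load_image` could be something like `lambda p: cv2.resize(cv2.imread("/content/test/leafsnap-dataset/" + p), (224, 224))`, with `paths` and `labels` taken from `arr[:, 0]` and `arr[:, 2]`, assuming that column layout. Two design points: the index walks through the whole dataset across successive batches (in the question's version, `for row in arr` restarts at the first row on every pass of `while True`, so every batch holds the same images), and the label encoding is fixed once instead of fitting a new `OneHotEncoder` per batch, which could assign different columns whenever a batch is missing some classes.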
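Second, you pass the generator function itself (`image_generator`) to `fit_generator`, when you should pass the generator object it returns, `gen`. As the traceback shows, Keras does not treat a plain function as a generator, so it falls back to the NumPy/EagerTensor input path and fails when it tries to read `.shape`.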
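A quick way to see the difference between the function and the object it returns, using a simplified, hypothetical stand-in for `image_generator` that yields numbers instead of images. It also computes `steps_per_epoch` as a number of batches up front; in the question, `steps_per_epoch = NUM_TRAIN_IMAGES` is evaluated when `fit_generator` is called, while that counter is still 0.

```python
import inspect
import math


def image_generator(rows, bs):
    """Hypothetical stand-in for the question's generator: yields slices of rows forever."""
    while True:
        for start in range(0, len(rows), bs):
            yield rows[start:start + bs]


rows = list(range(100))  # pretend dataset of 100 samples (assumption)
bs = 32

# Calling the function returns a generator object; the function itself is not one.
gen = image_generator(rows, bs)
print(inspect.isgeneratorfunction(image_generator))  # True
print(inspect.isgenerator(image_generator))          # False
print(inspect.isgenerator(gen))                      # True

# steps_per_epoch is a number of batches, and it must be known before training starts.
steps_per_epoch = math.ceil(len(rows) / bs)
print(steps_per_epoch)  # 4

# With a compiled Keras model, the training call would then be, for example:
# history = model.fit(gen, steps_per_epoch=steps_per_epoch, epochs=30)
```

The commented call uses `model.fit` because in recent TensorFlow 2.x versions `fit_generator` is deprecated and `fit` accepts a generator directly; with an older version, passing `gen` to `fit_generator` works the same way.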