Weij

Reputation: 11

model.fit_generator() shape error

import os
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense

img_width, img_height = 64, 64

train_data_dir = 'data/train'
validation_data_dir = 'data/validation'
# os.walk yields (dirpath, dirnames, filenames) tuples; count the filenames
nb_train_samples = sum(len(files) for _, _, files in os.walk(train_data_dir))
nb_validation_samples = sum(len(files) for _, _, files in os.walk(validation_data_dir))
nb_epoch = 10


model = Sequential()
model.add(Dense(4096, input_dim = 4096, init='normal', activation='relu'))
model.add(Dense(4,init='normal',activation='softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])


train_datagen = ImageDataGenerator(
        rescale=1./255,
        )


test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        color_mode="grayscale",
        target_size=(img_width, img_height),
        batch_size=1,
        class_mode=None)

validation_generator = test_datagen.flow_from_directory(
        validation_data_dir,
        color_mode="grayscale",
        target_size=(img_width, img_height),
        batch_size=1,
        class_mode=None)

model.fit_generator(
        train_generator,
        samples_per_epoch=nb_train_samples,
        nb_epoch=nb_epoch,
        validation_data=validation_generator,
        nb_val_samples=nb_validation_samples)

Everything runs fine until the model.fit_generator() call in the code above, which raises the following error:

Traceback (most recent call last):
  File "C:/Users/Sam/PycharmProjects/MLP/Testing Code without CNN.py", line 55, in <module>
    nb_val_samples=nb_validation_samples)
  File "C:\Python27\lib\site-packages\keras\models.py", line 874, in fit_generator
    pickle_safe=pickle_safe)
  File "C:\Python27\lib\site-packages\keras\engine\training.py", line 1427, in fit_generator
    'or (x, y). Found: ' + str(generator_output))
Exception: output of generator should be a tuple (x, y, sample_weight) or (x, y). Found: [[[[ 0.19215688]
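A side note on counting samples with os.walk: it yields (dirpath, dirnames, filenames) 3-tuples, so calling len() on each tuple always gives 3 rather than the number of files. A minimal sketch of a correct count follows. count_files is a hypothetical helper name, the cats/dogs tree is made up for the demo, and tempfile.TemporaryDirectory requires Python 3 (the traceback shows Python 2.7, where only count_files itself would carry over).

```python
import os
import tempfile


def count_files(root):
    """Count every file under root, including files inside class subfolders.

    os.walk yields (dirpath, dirnames, filenames) 3-tuples, so the file
    list is the third element of each tuple.
    """
    return sum(len(filenames) for _, _, filenames in os.walk(root))


if __name__ == "__main__":
    # Hypothetical layout: root/cats/a.png, root/cats/b.png, root/dogs/c.png
    with tempfile.TemporaryDirectory() as root:
        for sub, names in {"cats": ["a.png", "b.png"], "dogs": ["c.png"]}.items():
            os.makedirs(os.path.join(root, sub))
            for name in names:
                open(os.path.join(root, sub, name), "w").close()
        # len() of each walk tuple is 3 no matter what the folder holds;
        # three directories are visited here (root, cats, dogs).
        print(sum(len(entry) for entry in os.walk(root)))  # 9
        print(count_files(root))  # 3
```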

Upvotes: 1

Views: 2800

Answers (2)

pyan

Reputation: 3707

The problem is most likely caused by a data dimension mismatch. ImageDataGenerator loads the image files and puts each image into a NumPy array of shape (num_image_channel, image_height, image_width). But your first layer is a densely connected layer, which expects each sample to be a 1D array, i.e. a 2D input of shape (num_samples, features). So essentially you are missing an input layer that takes the input in the right shape.
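To make the shape argument concrete, here is a small NumPy-only sketch (no Keras) of what a flattening input layer has to do. It assumes "th" ordering, a batch size of 1 and 64x64 grayscale images as in the question; the pixel values are random placeholders standing in for a generator batch.

```python
import numpy as np

img_width, img_height = 64, 64
img_channel = 1  # color_mode="grayscale" gives a single channel
batch_size = 1

# Stand-in for one generator batch with "th" ordering:
# (samples, channels, height, width). Values are random placeholders.
batch = np.random.rand(batch_size, img_channel, img_height, img_width).astype("float32")

# A Dense layer wants (samples, features), so each image is flattened
# into a vector of channels * height * width values.
flat = batch.reshape(batch.shape[0], -1)

print(batch.shape)  # (1, 1, 64, 64)
print(flat.shape)   # (1, 4096)
```

For these settings the flattened length is 1 * 64 * 64 = 4096, which is the input_dim the question's first Dense layer declares; the mismatch is in the rank of the data, not the number of values.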

Change the following line of code

model.add(Dense(4096, input_dim = 4096, init='normal', activation='relu'))

to

from keras.layers import Reshape
model.add(Reshape((img_width*img_height*img_channel,), input_shape=(img_channel, img_height, img_width)))
model.add(Dense(4096, init='normal', activation='relu'))

You have to define img_channel, which is the number of channels in your images (1 for color_mode="grayscale"). The above code also assumes that you are using the "th" dim_ordering (channels first). If you are using "tf" dimension ordering (channels last), you would have to change the Reshape layer to

model.add(Reshape((img_width*img_height*img_channel,), input_shape=(img_height, img_width, img_channel)))
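The only difference between the two orderings is where the channel axis sits in input_shape; the flattened target shape is the same. The sketch below computes both tuples for either ordering. reshape_args is a hypothetical helper, not a Keras function, and it only builds the arguments you would pass to Reshape.

```python
def reshape_args(img_height, img_width, img_channel, dim_ordering="th"):
    """Return (target_shape, input_shape) for a flattening Reshape layer.

    "th" is channels-first and "tf" is channels-last; the flattened
    length is the same either way.
    """
    # The trailing comma matters: (n,) is a 1-tuple, (n) is just the int n.
    target_shape = (img_height * img_width * img_channel,)
    if dim_ordering == "th":
        input_shape = (img_channel, img_height, img_width)
    elif dim_ordering == "tf":
        input_shape = (img_height, img_width, img_channel)
    else:
        raise ValueError("dim_ordering must be 'th' or 'tf', got %r" % (dim_ordering,))
    return target_shape, input_shape


print(reshape_args(64, 64, 1, "th"))  # ((4096,), (1, 64, 64))
print(reshape_args(64, 64, 1, "tf"))  # ((4096,), (64, 64, 1))

# Without the trailing comma the "shape" is an int, not a tuple.
print(isinstance((64 * 64 * 1), tuple))   # False
print(isinstance((64 * 64 * 1,), tuple))  # True
```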

--- Old answer ---

You have probably put the training and validation data into subfolders under train and validation, which isn't supported by Keras. All training data should be in a single folder, and the same goes for the validation data.

Please refer to this Keras tutorial for more details.

Upvotes: 1

Maximilian Peters

Reputation: 31679

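I am not 100% sure what you are trying to achieve, but if you are doing binary classification of images, try setting class_mode to "binary" instead of None. With class_mode=None, flow_from_directory yields only batches of images and no labels, so fit_generator has nothing to train against. From the documentation: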

class_mode: one of "categorical", "binary", "sparse" or None. Default: "categorical". Determines the type of label arrays that are returned: "categorical" will be 2D one-hot encoded labels, "binary" will be 1D binary labels, "sparse" will be 1D integer labels.
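The label formats in that documentation excerpt can be illustrated with a small pure-Python sketch. labels_for is a hypothetical helper, not part of Keras; it mirrors the documented formats (one-hot rows for "categorical", 0/1 values for "binary", integer ids for "sparse", no labels for None) using plain lists instead of NumPy arrays.

```python
def labels_for(class_indices, class_mode, num_classes):
    """Mimic the label arrays described for each class_mode.

    class_indices: integer class id of each image in the batch.
    Returns None for class_mode=None, because no labels are produced.
    """
    if class_mode is None:
        return None
    if class_mode == "sparse":
        return list(class_indices)  # 1D integer labels
    if class_mode == "binary":
        return [float(i) for i in class_indices]  # 1D 0/1 labels, two classes assumed
    if class_mode == "categorical":
        # 2D one-hot rows, one row per image
        return [[1.0 if j == i else 0.0 for j in range(num_classes)] for i in class_indices]
    raise ValueError("unknown class_mode: %r" % (class_mode,))


print(labels_for([2, 0], "categorical", 4))  # [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
print(labels_for([2, 0], "sparse", 4))       # [2, 0]
print(labels_for([1, 0], "binary", 2))       # [1.0, 0.0]
print(labels_for([2, 0], None, 4))           # None
```

Since the question's model ends in a 4-unit softmax trained with categorical_crossentropy, the "categorical" format (one-hot rows of length 4) looks like the one that particular network expects.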

The error message is a bit confusing, but looking at the source code makes it clearer:

if not hasattr(generator_output, '__len__'):
    _stop.set()
    raise Exception('output of generator should be a tuple '
                    '(x, y, sample_weight) '
                    'or (x, y). Found: ' + str(generator_output))
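To see why the check trips on the question's generator, here is a simplified, hypothetical re-creation in plain Python; unpack_generator_output is an illustrative name, not Keras code. Besides the __len__ test quoted above, it assumes a follow-up test that the output has exactly 2 or 3 elements. A batch of images does have a length (its batch size), so with batch_size=1 and class_mode=None the output passes the first test but fails the second.

```python
def unpack_generator_output(generator_output):
    """Simplified re-creation of the (x, y[, sample_weight]) check.

    Hypothetical helper: it mirrors the quoted __len__ test plus an
    assumed length test, not Keras's actual implementation.
    """
    message = ("output of generator should be a tuple (x, y, sample_weight) "
               "or (x, y). Found: " + str(generator_output))
    if not hasattr(generator_output, "__len__"):
        raise ValueError(message)
    if len(generator_output) == 2:
        x, y = generator_output
        return x, y, None
    if len(generator_output) == 3:
        return tuple(generator_output)
    raise ValueError(message)


# One grayscale batch with batch_size=1 as nested lists
# (batch, channel, height, width); values are placeholders.
images = [[[[0.19215688, 0.5], [0.25, 0.75]]]]
labels = [[0.0, 1.0, 0.0, 0.0]]

# class_mode=None: only the images are yielded. The batch has length 1,
# so it passes the __len__ test but fails the 2-or-3 length test.
try:
    unpack_generator_output(images)
except ValueError as err:
    print("rejected:", err)

# With labels, the generator yields an (x, y) tuple and is accepted.
x, y, sw = unpack_generator_output((images, labels))
print(len(x), len(y), sw)  # 1 1 None
```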

Upvotes: 0
