Reputation: 1189
I'm trying to do an 8-class classification. Here is the code:
import keras
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dropout, Flatten, Dense
from keras import applications
from keras.optimizers import SGD
from keras import backend as K
K.set_image_dim_ordering('tf')
img_width, img_height = 48,48
top_model_weights_path = 'modelom.h5'
train_data_dir = 'chCdata1/train'
validation_data_dir = 'chCdata1/validation'
nb_train_samples = 6400
nb_validation_samples = 1600
epochs = 50
batch_size = 10
def save_bottlebeck_features():
    datagen = ImageDataGenerator(rescale=1. / 255)
    model = applications.VGG16(include_top=False, weights='imagenet', input_shape=(48,48,3))
    generator = datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=False)
    bottleneck_features_train = model.predict_generator(
        generator, nb_train_samples // batch_size)
    np.save(open('bottleneck_features_train', 'wb'),bottleneck_features_train)
    generator = datagen.flow_from_directory(
        validation_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=False)
    bottleneck_features_validation = model.predict_generator(
        generator, nb_validation_samples // batch_size)
    np.save(open('bottleneck_features_validation', 'wb'),bottleneck_features_validation)
def train_top_model():
    train_data = np.load(open('bottleneck_features_train', 'rb'))
    train_labels = np.array([0] * (nb_train_samples // 8) + [1] * (nb_train_samples // 8) + [2] * (nb_train_samples // 8) + [3] * (nb_train_samples // 8) + [4] * (nb_train_samples // 8) + [5] * (nb_train_samples // 8) + [6] * (nb_train_samples // 8) + [7] * (nb_train_samples // 8))
    validation_data = np.load(open('bottleneck_features_validation', 'rb'))
    validation_labels = np.array([0] * (nb_train_samples // 8) + [1] * (nb_train_samples // 8) + [2] * (nb_train_samples // 8) + [3] * (nb_train_samples // 8) + [4] * (nb_train_samples // 8) + [5] * (nb_train_samples // 8) + [6] * (nb_train_samples // 8) + [7] * (nb_train_samples // 8))
    train_labels = keras.utils.to_categorical(train_labels, num_classes = 8)
    validation_labels = keras.utils.to_categorical(validation_labels, num_classes = 8)
    model = Sequential()
    model.add(Flatten(input_shape=train_data.shape[1:]))
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(8, activation='softmax'))
    sgd = SGD(lr=1e-2, decay=0.00371, momentum=0.9, nesterov=False)
    model.compile(optimizer=sgd,
                  loss='categorical_crossentropy', metrics=['accuracy'])
    model.fit(train_data, train_labels,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(validation_data, validation_labels))
    model.save_weights(top_model_weights_path)
save_bottlebeck_features()
train_top_model()
Here is the full traceback:
Traceback (most recent call last):
File "<ipython-input-14-1d34826b5dd5>", line 1, in <module>
runfile('C:/Users/rajaramans2/codes/untitled15.py', wdir='C:/Users/rajaramans2/codes')
File "C:\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 866, in runfile
execfile(filename, namespace)
File "C:\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 102, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)
File "C:/Users/rajaramans2/codes/untitled15.py", line 71, in <module>
train_top_model()
File "C:/Users/rajaramans2/codes/untitled15.py", line 67, in train_top_model
validation_data=(validation_data, validation_labels))
File "C:\Anaconda3\lib\site-packages\keras\models.py", line 856, in fit
initial_epoch=initial_epoch)
File "C:\Anaconda3\lib\site-packages\keras\engine\training.py", line 1449, in fit
batch_size=batch_size)
File "C:\Anaconda3\lib\site-packages\keras\engine\training.py", line 1317, in _standardize_user_data
_check_array_lengths(x, y, sample_weights)
File "C:\Anaconda3\lib\site-packages\keras\engine\training.py", line 235, in _check_array_lengths
'and ' + str(list(set_y)[0]) + ' target samples.')
ValueError: Input arrays should have the same number of samples as target arrays. Found 1600 input samples and 6400 target samples.
Running the script raises "ValueError: Input arrays should have the same number of samples as target arrays. Found 1600 input samples and 6400 target samples". How should I modify the code to fix this? Thanks in advance.
Upvotes: 16
Views: 47383
Reputation: 666
The problem in this case is on this line:
validation_labels = np.array([0] * (nb_train_samples // 8) + [1] * (nb_train_samples // 8) + [2] * (nb_train_samples // 8) + [3] * (nb_train_samples // 8) + [4] * (nb_train_samples // 8) + [5] * (nb_train_samples // 8) + [6] * (nb_train_samples // 8) + [7] * (nb_train_samples // 8))
Every occurrence of nb_train_samples in it should be replaced with nb_validation_samples. There is certainly a more compact way of writing this, since as it stands the same fix has to be made eight times.
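One possible way to write these labels more compactly, as a sketch rather than the only option: it assumes, like the question's code, that every class has the same number of images and that the generator was run with shuffle=False, so samples come out grouped by class. The helper name make_labels is hypothetical.

```python
import numpy as np

NUM_CLASSES = 8


def make_labels(n_samples, num_classes=NUM_CLASSES):
    """Return integer labels 0..num_classes-1, each repeated n_samples // num_classes times."""
    # Assumption: classes are balanced and samples are ordered by class.
    return np.repeat(np.arange(num_classes), n_samples // num_classes)


train_labels = make_labels(6400)       # nb_train_samples in the question
validation_labels = make_labels(1600)  # nb_validation_samples, not nb_train_samples

print(train_labels.shape, validation_labels.shape)
```

The arrays can then be passed to keras.utils.to_categorical exactly as in the question.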
Upvotes: 0
Reputation: 21
It is not about len(X_train) != len(y_train).
Split the data into equal-sized sets for training and testing (validation). Make sure the number of input samples is even. If it is not, trim the data by dropping the last observation.
train_test_split(X,y, test_size = 0.5, random_state=42)
This worked for me.
Upvotes: 1
Reputation: 5373
I know you already have an answer, but for other travelers: make sure the number of training samples is divisible by your batch_size.
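To illustrate why this matters for code like the question's: predict_generator is called there with nb_train_samples // batch_size steps, so the integer division silently drops any partial batch. A small arithmetic sketch, assuming each step yields one full batch (the helper name samples_covered and the batch size of 30 are only for illustration):

```python
def samples_covered(n_samples, batch_size):
    """Number of samples seen when running n_samples // batch_size full batches."""
    steps = n_samples // batch_size  # integer division drops any partial batch
    return steps * batch_size


print(samples_covered(6400, 10))  # 6400: divisible, every sample is covered
print(samples_covered(6400, 30))  # 6390: the last 10 samples are never predicted
```

In the second case the saved features would have 6390 rows while a label array built from nb_train_samples would still have 6400 entries.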
Upvotes: 0
Reputation: 86600
Looks like you have 1600 examples for training. And your 8 classes are not separated in samples, so you have an array with 8 x 1600 = 6400 values.
That array must have a shape such as (1600, 8), that is, 1600 samples with 8 possible classes.
Now you need to know how your train_labels array is organized. Maybe a simple reshape((1600,8)) is enough, if the array is properly ordered.
If not, you have to organize it yourself into 1600 samples of eight labels.
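As a small sketch of the target shape described here: integer class labels for 1600 samples, once one-hot encoded, form a (1600, 8) array. The snippet indexes a NumPy identity matrix as a stand-in for keras.utils.to_categorical, and the balanced, class-ordered label layout is an assumption borrowed from the question.

```python
import numpy as np

num_classes = 8
n_samples = 1600

# Assumption: 200 samples per class, ordered by class.
labels = np.repeat(np.arange(num_classes), n_samples // num_classes)

# Row i of the identity matrix is the one-hot vector for class i.
one_hot = np.eye(num_classes, dtype="float32")[labels]

print(one_hot.shape)  # (1600, 8)
```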
Upvotes: 6
Reputation: 561
It looks like the number of examples in X_train (i.e. train_data) doesn't match the number of examples in y_train (i.e. train_labels). Can you double-check? And in the future, please attach the full error, since it helps in debugging the issue.
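A quick way to run that check before calling fit, as a sketch: the helper name check_same_length is hypothetical, and the zero-filled arrays only stand in for the real features and labels (the feature shape (1, 1, 512) is an assumption, not something shown in the question).

```python
import numpy as np


def check_same_length(x, y, name):
    """Raise a readable error if inputs and targets differ in sample count."""
    if len(x) != len(y):
        raise ValueError(
            "%s: %d input samples but %d target samples" % (name, len(x), len(y)))


# Placeholder arrays mirroring the mismatch reported in the traceback.
validation_data = np.zeros((1600, 1, 1, 512), dtype="float32")
validation_labels = np.zeros((6400, 8), dtype="float32")

message = None
try:
    check_same_length(validation_data, validation_labels, "validation")
except ValueError as err:
    message = str(err)
    print(message)
```

Running the same check on both the training pair and the validation pair shows which of the two is mismatched.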
Upvotes: 23