Reputation: 573
I am trying to train an image classifier with Keras and I keep getting this error:
InvalidArgumentError: logits and labels must be broadcastable: logits_size=[32,4] labels_size=[32,2] [[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at :2) ]] [Op:__inference_train_function_10520]
Function call stack: train_function
I am creating my model like this:
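# Imports assumed (tf.keras); they were not shown in the original post
import os
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator
base_model = ResNet50(include_top=False, weights='imagenet')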
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(4, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics = ['accuracy'])
data_folder = os.path.join("data", "train_min")
test_folder = os.path.join("data", "test_min")
train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)
test_datagen = ImageDataGenerator(rescale = 1./255)
training_set = train_datagen.flow_from_directory(data_folder,
                                                 target_size = (224, 224),
                                                 batch_size = 32,
                                                 class_mode = 'categorical')
test_set = test_datagen.flow_from_directory(test_folder,
                                            target_size = (224, 224),
                                            batch_size = 32,
                                            class_mode = 'categorical')
After creating the training_set and the test_set I get these messages:
Found 3520 images belonging to 2 classes. (training_set)
Found 480 images belonging to 2 classes. (test_set)
So loading the images seems to work fine.
But when I try to execute this code:
model.fit_generator(training_set,
                    steps_per_epoch = 8000,
                    epochs = 5,
                    validation_data = test_set,
                    validation_steps = 200)
I get the error I already showed above:
InvalidArgumentError: logits and labels must be broadcastable: logits_size=[32,4] labels_size=[32,2] [[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at :2) ]] [Op:__inference_train_function_10520]
Function call stack: train_function
How do I change the label size? Isn't labeling done automatically when I create the training_set? And what are logits?
Upvotes: 0
Views: 1448
Reputation: 21
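Try changing 4 to 2 in the line predictions = Dense(4, activation='softmax')(x), so that the output layer matches the 2 classes the generators found. The error says exactly this: the model outputs 4 values per image (logits_size=[32,4]) while each one-hot label has 2 entries (labels_size=[32,2]). "Logits" here refers to the outputs of that last layer, and labeling is done automatically, one class per subfolder.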
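To avoid hard-coding the number, one option is to derive it from the data folder. The sketch below is only an illustration: count_classes is a hypothetical helper (not part of Keras) that counts subfolders, on the assumption that flow_from_directory treats each immediate subfolder as one class. The Keras lines are left as comments because they depend on your data.

```python
import os

def count_classes(directory):
    """Count the immediate subfolders of `directory` (assumed: one per class)."""
    return sum(
        1
        for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )

# Hypothetical usage with the folders from the question:
# num_classes = count_classes(os.path.join("data", "train_min"))
# predictions = Dense(num_classes, activation='softmax')(x)
#
# The iterator returned by flow_from_directory should also expose this
# count as training_set.num_classes, if you build the generator first.
```

With the data from the question this should give 2, so the output layer and the one-hot labels would have the same width.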
Upvotes: 2