Runeater DaWizKid

Reputation: 81

Model.fit in keras with multi-label classification

I'm trying to learn how to use my own dataset with the model seen here: resnet, which is just a ResNet model written in Keras. Within the code they write this line:

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

and then use the respective data to 'Convert class vectors to binary class matrices.'

y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
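As a rough illustration of what that conversion produces, here is a NumPy-only sketch of one-hot encoding. It assumes integer labels of shape (num_samples, 1), which is how cifar10.load_data() returns them. The helper name one_hot is made up for this example and is not part of Keras.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Hypothetical stand-in that mimics one-hot encoding of integer labels."""
    labels = np.asarray(labels, dtype=int).ravel()  # (N, 1) -> (N,)
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(num_classes, dtype="float32")[labels]

y = np.array([[2], [0], [1]])          # three samples, classes 2, 0, 1
y_onehot = one_hot(y, num_classes=3)   # shape: (3, 3)
print(y_onehot)
```

Each row contains exactly one 1, which is what distinguishes this multiclass encoding from a multi-label one.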

and then pass these values into the fit function for the model that was built like so:

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test),
          shuffle=True,
          callbacks=callbacks)

I believe that I can create x_train by doing something similar to this (assuming I have an array of image paths):

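import numpy as np
from PIL import Image

# Sketch: `images` is a list of image file paths
x_train = []
for image in images:
    # Force 3 channels and one common size (32x32 assumed here, as in CIFAR-10)
    im = np.asarray(Image.open(image).convert("RGB").resize((32, 32)))
    x_train.append(im)
x_train = np.stack(x_train)  # shape: (num_images, 32, 32, 3)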

Is the above correct?

As for y_train, I do not quite understand what is being passed into model.fit. Is it an array of one-hot encoded arrays? So if I had 3 images containing a cat and a dog, a dog, and a cat respectively, would y_train be

[
 [1, 1, 0],#cat and dog
 [0, 1, 0],#dog
 [1, 0, 0]#cat
]

or am I mistaken on this as well?

Upvotes: 0

Views: 1241

Answers (1)

Shubham Panchal

Reputation: 4289

So, model.fit() expects x_train as the features and y_train as the labels for a particular classification problem. I'll consider multiclass image classification here.

  • x_train: For image classification, this argument has the shape (num_images, height, width, num_channels) with Keras's default channels_last data format, where num_images is the number of images in the training set. See here.

  • y_train: The one-hot encoded labels. The required shape is (num_images, num_classes).

Notice that num_images is common to both arguments. You need to ensure that there is an equal number of images and labels.
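For the multi-label case in the question, where one image can contain both a cat and a dog, the label rows would be multi-hot rather than one-hot. A minimal sketch, assuming a hypothetical class order of cat, dog, bird and random arrays standing in for real images:

```python
import numpy as np

class_names = ["cat", "dog", "bird"]  # hypothetical class order
class_index = {name: i for i, name in enumerate(class_names)}

# Labels per image, matching the example in the question
image_labels = [["cat", "dog"], ["dog"], ["cat"]]

# One row per image, one column per class; 1.0 marks each class present
y_train = np.zeros((len(image_labels), len(class_names)), dtype="float32")
for row, labels in enumerate(image_labels):
    for name in labels:
        y_train[row, class_index[name]] = 1.0

# Placeholder data: 3 RGB images of 32x32 pixels (channels_last layout)
x_train = np.random.rand(len(image_labels), 32, 32, 3).astype("float32")

# The first dimension must match between images and labels
assert x_train.shape[0] == y_train.shape[0]
print(y_train)
```

With labels like these, a common setup is a final Dense layer with a sigmoid activation and binary_crossentropy loss, rather than softmax with categorical_crossentropy, so that each class is predicted independently.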

Hope that helps.

Upvotes: 2
