conch_shell_ninja

Reputation: 41

What's keeping this simple CNN from classifying whether an image contains a cat or not?

Instead of "cats vs dogs", I'm trying "cats vs. everything else" on a brand new network (no transfer learning) using a large number of random internet images I've sorted into "cat" or "no cat" categories.

Unfortunately, my network doesn't seem to train beyond random-chance performance on this task.

My networks have been basic multi-layer CNNs with a single large dense layer followed by a single sigmoid neuron at the end, whose output (between 0 and 1) signifies "cat" or "no cat". My "cat" images consist of many cropped images of various cat breeds in many poses and angles, in differing environments and backgrounds. I've tried various forms of data augmentation, image weighting, and training/validation shuffling, but still can't get a useful network out of it.

What am I missing? Is there something wrong with my architecture or approach? I'm a machine learning newbie, using Keras with the TensorFlow backend.

My network architecture is as follows:

Input: a 320x320 RGB image, then:

1a) Conv layer, 32 filters, 3x3
1b) ReLU and max pooling 2x2

2a) Conv layer, 32 filters, 3x3
2b) ReLU and max pooling 2x2

3a) Conv layer, 64 filters, 3x3
3b) ReLU and max pooling 2x2

4) Flatten

5a) Dense layer, 64 neurons
5b) ReLU
5c) Dropout 0.5
5d) Final dense layer, 1 neuron, sigmoid activation (0 = cat in image, 1 = no cat in image)
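To see where the parameters in this architecture end up, here is a small pure-Python sketch that tracks the tensor shape and parameter count layer by layer. It assumes stride-1 3x3 convolutions with Keras' default `padding="valid"` and 2x2 max pooling with stride 2; neither detail is stated above, and the helper names (`conv_valid`, `pool2`, `conv_params`, `summarize`) are hypothetical.

```python
# Shape and parameter bookkeeping for the architecture listed above.
# Assumptions (not stated in the question): stride-1 convolutions with
# Keras' default padding="valid", and 2x2 max pooling with stride 2.


def conv_valid(size, kernel=3):
    """Spatial size after a stride-1 'valid' convolution."""
    return size - kernel + 1


def pool2(size):
    """Spatial size after 2x2 max pooling with stride 2 (floor division)."""
    return size // 2


def conv_params(in_ch, out_ch, kernel=3):
    """Kernel weights plus one bias per filter."""
    return kernel * kernel * in_ch * out_ch + out_ch


def summarize(input_size=320, in_ch=3, filters=(32, 32, 64), dense_units=64):
    """Return (layer name, output shape, parameter count) rows."""
    size, ch = input_size, in_ch
    rows = []
    for i, f in enumerate(filters, start=1):
        size = pool2(conv_valid(size))
        rows.append((f"block{i} conv{f}+pool", (size, size, f), conv_params(ch, f)))
        ch = f
    flat = size * size * ch
    rows.append(("flatten", (flat,), 0))
    rows.append((f"dense{dense_units}", (dense_units,), flat * dense_units + dense_units))
    rows.append(("dense1 (sigmoid)", (1,), dense_units + 1))
    return rows


if __name__ == "__main__":
    rows = summarize()
    for name, shape, params in rows:
        print(f"{name:22s} {str(shape):16s} {params:>10,d}")
    print("total params:", f"{sum(p for _, _, p in rows):,d}")
```

Under those assumptions, the flattened feature vector has 92,416 values, so the Dense(64) layer alone holds 5,914,688 of the roughly 5.94 million parameters (about 99.5%).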

I'm using binary cross-entropy loss, and image rotation/shift/flip/etc. for data augmentation. My dataset is highly imbalanced, with 1 cat picture for every 5 non-cat pictures. I've reserved 25% of my dataset for validation, with the same 1:5 imbalance. To compensate, I've weighted the fit generator so that cat images get 5x the weight of non-cat images.
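It can help to pin down what "random" looks like with this setup. The sketch below uses hypothetical counts (1,000 cat and 5,000 non-cat images) that match the stated 1:5 ratio and the label convention from the architecture list (0 = cat, 1 = no cat). It shows the 5x class weights as a `{label: weight}` dict (the form Keras' `class_weight` argument accepts), the accuracy of always answering "no cat", and the loss of a constant output. The loss here is normalized by the total weight, which is an assumption; the value Keras reports may be scaled differently depending on version and reduction.

```python
import math

# Hypothetical counts reflecting the 1:5 cat/no-cat ratio described above.
N_CAT, N_NO_CAT = 1000, 5000

# Label convention from the question: 0 = cat, 1 = no cat.
CAT, NO_CAT = 0, 1


def class_weights(n_cat, n_no_cat):
    """Weights that give both classes the same total weight.

    Returned as a {label: weight} dict, the form Keras' class_weight accepts.
    """
    return {CAT: n_no_cat / n_cat, NO_CAT: 1.0}


def majority_baseline_accuracy(n_cat, n_no_cat):
    """Accuracy of a model that always predicts 'no cat'."""
    return n_no_cat / (n_cat + n_no_cat)


def constant_prediction_loss(p_no_cat, n_cat, n_no_cat, weights):
    """Weight-normalized binary cross-entropy when the sigmoid output is constant.

    p_no_cat is the sigmoid output, i.e. the predicted probability of label 1.
    """
    loss_cat = -math.log(1.0 - p_no_cat)  # true label 0 (cat)
    loss_no_cat = -math.log(p_no_cat)     # true label 1 (no cat)
    w_cat = weights[CAT] * n_cat
    w_no_cat = weights[NO_CAT] * n_no_cat
    return (w_cat * loss_cat + w_no_cat * loss_no_cat) / (w_cat + w_no_cat)


if __name__ == "__main__":
    w = class_weights(N_CAT, N_NO_CAT)
    print("class weights:", w)
    print("always 'no cat' accuracy:", round(majority_baseline_accuracy(N_CAT, N_NO_CAT), 4))
    print("loss at constant 0.5:", round(constant_prediction_loss(0.5, N_CAT, N_NO_CAT, w), 4))
```

With a 1:5 split, a network that ignores the image and always says "no cat" already reaches about 83% validation accuracy. With the 5x weighting, both classes carry equal total weight, so the best constant output is 0.5, giving a weight-normalized loss of ln 2 (about 0.693). A loss stuck near that value suggests the network is producing roughly the same output for every image.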

Is there a problem with the architecture, initialization, etc., or am I simply severely underestimating the time it takes to train a CNN from scratch on modern hardware?

Upvotes: 1

Views: 367

Answers (1)

Hugues Fontenelle

Reputation: 5435

The "cats vs dogs" example often uses VGG16, as in the fast.ai course.

VGG16 is the 16-layer network the VGG team used in the ILSVRC-2014 ImageNet competition. See one implementation in Keras.

You could probably download the pretrained weights, then apply the network to your problem after downscaling your images to 224x224 pixels with 3 channels.
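A minimal transfer-learning sketch of that suggestion, assuming TensorFlow 2.x with its bundled Keras (`tensorflow.keras`) rather than the standalone Keras of the question's era. The function name `build_vgg16_cat_classifier`, the frozen base, the global-average-pooling head and the 0.5 dropout rate are illustrative choices, not part of the original answer.

```python
from tensorflow import keras

IMG_SIZE = 224  # resolution the ImageNet VGG16 weights were trained at


def build_vgg16_cat_classifier(weights="imagenet"):
    """Hypothetical helper: frozen VGG16 conv base plus a small trainable sigmoid head.

    Inputs are expected to be already passed through
    keras.applications.vgg16.preprocess_input (RGB -> BGR, ImageNet mean
    subtraction), e.g. in the data pipeline.
    """
    base = keras.applications.VGG16(
        include_top=False,  # drop the 1000-class ImageNet classifier
        weights=weights,    # "imagenet" downloads pretrained weights; None = random init
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
    )
    base.trainable = False  # keep the pretrained filters fixed; only the head trains

    inputs = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    x = base(inputs, training=False)
    x = keras.layers.GlobalAveragePooling2D()(x)  # (7, 7, 512) -> (512,)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(1, activation="sigmoid")(x)  # single sigmoid output

    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return model


if __name__ == "__main__":
    model = build_vgg16_cat_classifier()
    model.summary()
```

In this setup, the 320x320 images would be resized to 224x224 and run through `keras.applications.vgg16.preprocess_input` before training, for example via the `target_size` argument of `flow_from_directory` and the `preprocessing_function` argument of `ImageDataGenerator`. Only the final Dense layer's weights are trainable, so each epoch is much cheaper than training the whole network from scratch.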

Upvotes: 1
