Marcin

Reputation: 1124

Design of a binary classifier based on a CNN

I have designed a CNN to classify images for automatic quality control. The input images are 320 x 320 px. The network has five convolutional layers, a fully connected (FC) layer with 512 outputs, and a final layer with only two outputs: 'good' or 'bad'. Quality control must be done with an accuracy of 1.0. I am using TensorFlow.

I am a beginner with CNNs and I have trouble evaluating my model. Although I get an accuracy of 1.0 on the training set, and sometimes also on the validation set, I am worried about the values of the cost function. My model outputs very large logits, and if I apply softmax to them I always get a probability of 100% for either 'good' or 'bad'. As a consequence, whenever my model predicts an example correctly, the cost (calculated with tf.nn.softmax_cross_entropy_with_logits) is 0. If all training examples are predicted correctly, the weights no longer change, and the model's performance on the validation set does not improve.
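For two classes, the softmax cross-entropy can be rewritten in terms of the gap between the two logits, which shows why the cost and its gradient vanish for confident correct predictions. In this sketch, $z_y$ is the logit of the true class, $z_{\bar y}$ the logit of the other class, and $p_y$, $p_{\bar y}$ the corresponding softmax probabilities (notation introduced here, not in the question):

$$
\mathcal{L} = -\log\frac{e^{z_y}}{e^{z_y} + e^{z_{\bar y}}} = \log\!\left(1 + e^{\,z_{\bar y} - z_y}\right)
$$

$$
\frac{\partial \mathcal{L}}{\partial z_y} = p_y - 1, \qquad \frac{\partial \mathcal{L}}{\partial z_{\bar y}} = p_{\bar y}
$$

When the model is right by a margin of several hundred, $e^{z_{\bar y} - z_y}$ is far below what floating point can represent, so the loss and both gradients are exactly 0 and backpropagation stops updating the weights. When it is wrong by such a margin, $\mathcal{L} \approx z_{\bar y} - z_y$, so the cost grows linearly with the logit gap.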

Here are example outputs of my model for a batch of 10 examples:

    Logits:
    [[ 2169.41455078  2981.38574219]
     [ 2193.54492188  3068.97509766]
     [ 2185.86743164  3060.24047852]
     [ 2305.94604492  3198.36083984]
     [ 2202.66503906  3136.44726562]
     [ 2305.78076172  2976.58081055]
     [ 2248.13232422  3130.26123047]
     [ 2259.94726562  3132.30200195]
     [ 2290.61303711  3098.0871582 ]
     [ 2500.9609375   3188.67456055]]

    Softmax:
    [[ 0.  1.]
     [ 0.  1.]
     [ 0.  1.]
     [ 0.  1.]
     [ 0.  1.]
     [ 0.  1.]
     [ 0.  1.]
     [ 0.  1.]
     [ 0.  1.]
     [ 0.  1.]]

    Cost calculated with tf.nn.softmax_cross_entropy_with_logits:
    [ 811.97119141    0.          874.37304688    0.          933.78222656
      670.80004883    0.            0.          807.47412109    0.        ]
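Recomputing the softmax and the cost from these logits in NumPy shows what is going on numerically. This is a minimal sketch, not TensorFlow's implementation: it assumes float32 (TensorFlow's default dtype) and uses hypothetical labels, because the question does not give them. The labels are inferred from which entries of the cost vector are non-zero, with 0 and 1 meaning the first and second output.

```python
import numpy as np

# Logits copied from the batch above, stored as float32 like TensorFlow's default.
logits = np.array([
    [2169.41455078, 2981.38574219],
    [2193.54492188, 3068.97509766],
    [2185.86743164, 3060.24047852],
    [2305.94604492, 3198.36083984],
    [2202.66503906, 3136.44726562],
    [2305.78076172, 2976.58081055],
    [2248.13232422, 3130.26123047],
    [2259.94726562, 3132.30200195],
    [2290.61303711, 3098.0871582],
    [2500.9609375, 3188.67456055],
], dtype=np.float32)

# Hypothetical labels (assumption): inferred from the non-zero entries of the
# reported cost vector; 0 = first output, 1 = second output.
labels = np.array([0, 1, 0, 1, 0, 0, 1, 1, 0, 1])


def stable_softmax(z):
    # Subtract the row maximum before exponentiating to avoid overflow.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def softmax_cross_entropy(z, y):
    # loss_i = logsumexp(z_i) - z_i[y_i], computed with the max-shift trick.
    m = z.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(z - m).sum(axis=1))
    return lse - z[np.arange(len(y)), y]


probs = stable_softmax(logits)        # exp of gaps of -670 or less underflows to 0
costs = softmax_cross_entropy(logits, labels)
print(probs)
print(costs)
```

Every non-zero cost is just the gap between the two logits of its row, e.g. 2981.38574219 - 2169.41455078 = 811.97119141 for the first example. Since every row puts all of its probability on the second output, the five examples with a non-zero cost are the ones this batch gets wrong, each by a margin of several hundred.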

What do you think the problem is here? Is my CNN too complicated for this application, so that it outputs 100% probabilities? Is it simply overfitting? Do you think dropout would help?

Upvotes: 1

Views: 344

Answers (1)

Da Tong

Reputation: 2026

The problem is overfitting. Here are some ideas for addressing it:

  1. Enlarge the training set, either by collecting more data or by generating transformed images (data augmentation) from the existing dataset.
  2. Add regularization: L1/L2 weight penalties, batch normalization, and dropout will all help.
  3. Consider using a pre-trained model, which is so-called transfer learning; refer to this tutorial.
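Ideas 1 and 2 can be combined in one Keras model. This is a minimal sketch, assuming TensorFlow 2.x with the built-in Keras preprocessing layers; it mirrors the question's layout (five conv layers, a 512-unit FC layer, two output logits), but the channel count, filter sizes, L2 strength, dropout rate, and the choice of flips and small rotations as augmentations are all illustrative assumptions, not values from the question. Flips are only appropriate if a flipped 'good' part is still 'good'.

```python
import tensorflow as tf

# Assumed hyperparameters (illustrative only).
IMG_SIZE = 320     # input size from the question
CHANNELS = 3       # assumption: RGB images
L2 = 1e-4          # assumed L2 weight-penalty strength
DROPOUT = 0.5      # assumed dropout rate on the FC layer


def build_model():
    reg = tf.keras.regularizers.l2(L2)
    layers = [
        tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, CHANNELS)),
        tf.keras.layers.Rescaling(1.0 / 255),   # assumption: uint8 pixel input
        # Idea 1: on-the-fly augmentation, active only when training=True.
        tf.keras.layers.RandomFlip("horizontal_and_vertical"),
        tf.keras.layers.RandomRotation(0.05),
    ]
    # Five conv blocks; each max-pool halves the spatial size (320 -> 10).
    for filters in (16, 32, 64, 128, 128):
        layers += [
            tf.keras.layers.Conv2D(filters, 3, padding="same", use_bias=False,
                                   kernel_regularizer=reg),   # idea 2: L2
            tf.keras.layers.BatchNormalization(),             # idea 2: batch norm
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPooling2D(),
        ]
    layers += [
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation="relu", kernel_regularizer=reg),
        tf.keras.layers.Dropout(DROPOUT),                     # idea 2: dropout
        tf.keras.layers.Dense(2),                             # raw logits
    ]
    model = tf.keras.Sequential(layers)
    model.compile(
        optimizer="adam",
        # from_logits=True applies softmax inside the loss, as
        # tf.nn.softmax_cross_entropy_with_logits does.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


model = build_model()
```

For idea 3, the convolutional stack could instead be replaced by a frozen pre-trained backbone from `tf.keras.applications` with `include_top=False`, keeping only the small classification head trainable.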

Upvotes: 1
