user13921845

Reputation: 1

Why is my binary classification model not learning, even to overfit?

I have the following model, using tensorflow 2.2.0 with keras:

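from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

def get_model(input_shape):
  model = keras.Sequential()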
  
  model.add(Conv2D(32, input_shape=input_shape, kernel_size=(3, 3)))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Conv2D(64, kernel_size=(3, 3)))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  
  model.add(Conv2D(64, kernel_size=(3, 3)))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Flatten())

  model.add(Dense(64, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))

  return model

The shape of the input is (25, 25, 4) - a 3-dimensional image, 25x25px, with 4 channels. The model does not learn - it won't even overfit! I am trying to fit using the following incantation:

model.compile(optimizer='sgd', metrics=['accuracy'], loss='binary_crossentropy')
model.fit(trainX, trainY, validation_split=0.2, epochs=10, batch_size=50)

I have also tried changing the optimizer to be sgd with the same results and have tried varying batch sizes (including 1). An example of training for 10 epochs:

Epoch 1/10
763/763 [==============================] - 4s 5ms/step - loss: 0.6935 - accuracy: 0.5045 - val_loss: 0.6937 - val_accuracy: 0.5031
Epoch 2/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6935 - accuracy: 0.5020 - val_loss: 0.6946 - val_accuracy: 0.4972
Epoch 3/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6935 - accuracy: 0.5016 - val_loss: 0.6932 - val_accuracy: 0.4984
Epoch 4/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6934 - accuracy: 0.5020 - val_loss: 0.6932 - val_accuracy: 0.4986
Epoch 5/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5027 - val_loss: 0.6934 - val_accuracy: 0.4972
Epoch 6/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6932 - accuracy: 0.5051 - val_loss: 0.6946 - val_accuracy: 0.5019
Epoch 7/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5017 - val_loss: 0.6932 - val_accuracy: 0.4959
Epoch 8/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5017 - val_loss: 0.6934 - val_accuracy: 0.5056
Epoch 9/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6932 - accuracy: 0.5040 - val_loss: 0.6931 - val_accuracy: 0.5009
Epoch 10/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5018 - val_loss: 0.6931 - val_accuracy: 0.5020
<tensorflow.python.keras.callbacks.History at 0x7f761a0856d8>

For what it's worth, the data is almost certainly not the problem - I have tried other machine learning methods such as random forests and gradient boosting and they are able to overfit just fine.

Am I missing something fundamental here?

Edit: setting the activation of conv layers to relu does not help. The below output is with relu:


Epoch 1/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6936 - accuracy: 0.4990 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 2/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6933 - accuracy: 0.5026 - val_loss: 0.6931 - val_accuracy: 0.5043
Epoch 3/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6933 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 4/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5004 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 5/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.4992 - val_loss: 0.6932 - val_accuracy: 0.5029
Epoch 6/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5031 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 7/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 8/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5001 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 9/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5029 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 10/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5012 - val_loss: 0.6931 - val_accuracy: 0.5029
<tensorflow.python.keras.callbacks.History at 0x7f29766804a8>

I also tried changing the labels to be categorical and using categorical_crossentropy, to no avail.

Edit 2: The same behaviour persists over more epochs, with activation correctly set.

Model:

def get_model(input_shape):
  model = keras.Sequential()
  
  model.add(Conv2D(32, input_shape=input_shape, kernel_size=(3, 3), activation='relu'))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  
  model.add(Flatten())

  model.add(Dense(64, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))

  return model

Output:

Epoch 1/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6937 - accuracy: 0.4998 - val_loss: 0.6931 - val_accuracy: 0.5008
...
Epoch 243/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 244/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5007 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 245/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5014 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 246/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5035 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 247/250
1907/1907 [==============================] - 7s 4ms/step - loss: 0.6932 - accuracy: 0.5031 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 248/250
1907/1907 [==============================] - 7s 4ms/step - loss: 0.6932 - accuracy: 0.5026 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 249/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5018 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 250/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5007 - val_loss: 0.6931 - val_accuracy: 0.5029

Data sample:

display(trainX[0])
display(trainX[0].shape)
---
array([[[-0.81307793, -0.80876915, -0.80270227, -0.81340067],
        [-0.81323822, -0.80901267, -0.80424022, -0.81004681],
        [-0.80974839, -0.80952621, -0.80894936, -0.81924987],
        [-0.81901061, -0.81892894, -0.8198063 , -0.82950191],
        [-0.82926863, -0.82535357, -0.81962295, -0.82940024],
        [-0.82911602, -0.82669005, -0.81815252, -0.82725751],
        [-0.82717653, -0.82594539, -0.81691338, -0.82605227],
        [-0.82584266, -0.82452835, -0.81556359, -0.82556375],
        [-0.82525089, -0.82266387, -0.8177839 , -0.82243512],
        [-0.82222369, -0.82112803, -0.82649334, -0.83150323]],

       [[-0.81323822, -0.80901267, -0.80424022, -0.81004681],
        [-0.81339844, -0.80925606, -0.80577279, -0.80666623],
        [-0.80990994, -0.8097693 , -0.81046532, -0.81594339],
        [-0.81916858, -0.81916656, -0.82128286, -0.82628101],
        [-0.8294225 , -0.82558735, -0.82110019, -0.82617847],
        [-0.82926995, -0.82692302, -0.81963519, -0.82401759],
        [-0.82733125, -0.82617881, -0.8184006 , -0.82280219],
        [-0.82599791, -0.82476263, -0.81705573, -0.82230957],
        [-0.82540639, -0.82289927, -0.81926792, -0.81915487],
        [-0.8223804 , -0.82136435, -0.82794484, -0.82829943]],

       [[-0.80974839, -0.80952621, -0.80894936, -0.81924987],
        [-0.80990994, -0.8097693 , -0.81046532, -0.81594339],
        [-0.80639256, -0.81028192, -0.81510641, -0.82501505],
        [-0.81572868, -0.81966765, -0.82580199, -0.83511534],
        [-0.82607158, -0.82608032, -0.82562142, -0.8350152 ],
        [-0.82591768, -0.82741428, -0.8241732 , -0.83290467],
        [-0.8239619 , -0.82667103, -0.82295269, -0.83171742],
        [-0.82261689, -0.82525665, -0.82162309, -0.83123616],
        [-0.82202021, -0.82339567, -0.82381013, -0.82815378],
        [-0.818968  , -0.82186268, -0.83238638, -0.83708633]],

       [[-0.81901061, -0.81892894, -0.8198063 , -0.82950191],
        [-0.81916858, -0.81916656, -0.82128286, -0.82628101],
        [-0.81572868, -0.81966765, -0.82580199, -0.83511534],
        [-0.82485699, -0.82883834, -0.8362085 , -0.84494163],
        [-0.83496124, -0.83509975, -0.83603291, -0.84484426],
        [-0.83481096, -0.83640177, -0.83462448, -0.84279185],
        [-0.83290099, -0.83567633, -0.83343734, -0.84163708],
        [-0.83158729, -0.83429571, -0.83214391, -0.84116895],
        [-0.83100444, -0.83247886, -0.83427135, -0.83817006],
        [-0.82802254, -0.830982  , -0.84260906, -0.84685789]],

       [[-0.82926863, -0.82535357, -0.81962295, -0.82940024],
        [-0.8294225 , -0.82558735, -0.82110019, -0.82617847],
        [-0.82607158, -0.82608032, -0.82562142, -0.8350152 ],
        [-0.83496124, -0.83509975, -0.83603291, -0.84484426],
        [-0.84479157, -0.84125479, -0.83585723, -0.84474687],
        [-0.84464544, -0.84253437, -0.83444812, -0.84269387],
        [-0.84278804, -0.84182145, -0.8332604 , -0.84153877],
        [-0.84151027, -0.84046456, -0.83196635, -0.84107051],
        [-0.84094331, -0.83867874, -0.83409482, -0.83807077],
        [-0.83804212, -0.83720728, -0.84243663, -0.84676108]],

       [[-0.82911602, -0.82669005, -0.81815252, -0.82725751],
        [-0.82926995, -0.82692302, -0.81963519, -0.82401759],
        [-0.82591768, -0.82741428, -0.8241732 , -0.83290467],
        [-0.83481096, -0.83640177, -0.83462448, -0.84279185],
        [-0.84464544, -0.84253437, -0.83444812, -0.84269387],
        [-0.84449925, -0.84380921, -0.83303354, -0.84062854],
        [-0.84264105, -0.84309893, -0.83184123, -0.83946655],
        [-0.84136274, -0.84174705, -0.8305422 , -0.83899551],
        [-0.84079554, -0.83996778, -0.83267886, -0.83597806],
        [-0.83789312, -0.83850169, -0.84105351, -0.84472027]],

       [[-0.82717653, -0.82594539, -0.81691338, -0.82605227],
        [-0.82733125, -0.82617881, -0.8184006 , -0.82280219],
        [-0.8239619 , -0.82667103, -0.82295269, -0.83171742],
        [-0.83290099, -0.83567633, -0.83343734, -0.84163708],
        [-0.84278804, -0.84182145, -0.8332604 , -0.84153877],
        [-0.84264105, -0.84309893, -0.83184123, -0.83946655],
        [-0.84077276, -0.84238718, -0.83064506, -0.83830071],
        [-0.83948755, -0.84103251, -0.82934186, -0.83782811],
        [-0.8389173 , -0.83924958, -0.83148541, -0.83480076],
        [-0.8359994 , -0.8377805 , -0.83988758, -0.84357199]],

       [[-0.82584266, -0.82452835, -0.81556359, -0.82556375],
        [-0.82599791, -0.82476263, -0.81705573, -0.82230957],
        [-0.82261689, -0.82525665, -0.82162309, -0.83123616],
        [-0.83158729, -0.83429571, -0.83214391, -0.84116895],
        [-0.84151027, -0.84046456, -0.83196635, -0.84107051],
        [-0.84136274, -0.84174705, -0.8305422 , -0.83899551],
        [-0.83948755, -0.84103251, -0.82934186, -0.83782811],
        [-0.83819763, -0.83967254, -0.82803413, -0.83735488],
        [-0.8376253 , -0.83788269, -0.83018513, -0.83432354],
        [-0.83469681, -0.83640794, -0.83861716, -0.84310649]],

       [[-0.82525089, -0.82266387, -0.8177839 , -0.82243512],
        [-0.82540639, -0.82289927, -0.81926792, -0.81915487],
        [-0.82202021, -0.82339567, -0.82381013, -0.82815378],
        [-0.83100444, -0.83247886, -0.83427135, -0.83817006],
        [-0.84094331, -0.83867874, -0.83409482, -0.83807077],
        [-0.84079554, -0.83996778, -0.83267886, -0.83597806],
        [-0.8389173 , -0.83924958, -0.83148541, -0.83480076],
        [-0.8376253 , -0.83788269, -0.83018513, -0.83432354],
        [-0.83705204, -0.83608379, -0.83232385, -0.83126675],
        [-0.83411887, -0.83460162, -0.8407067 , -0.84012427]],

       [[-0.82222369, -0.82112803, -0.82649334, -0.83150323],
        [-0.8223804 , -0.82136435, -0.82794484, -0.82829943],
        [-0.818968  , -0.82186268, -0.83238638, -0.83708633],
        [-0.82802254, -0.830982  , -0.84260906, -0.84685789],
        [-0.83804212, -0.83720728, -0.84243663, -0.84676108],
        [-0.83789312, -0.83850169, -0.84105351, -0.84472027],
        [-0.8359994 , -0.8377805 , -0.83988758, -0.84357199],
        [-0.83469681, -0.83640794, -0.83861716, -0.84310649],
        [-0.83411887, -0.83460162, -0.8407067 , -0.84012427],
        [-0.83116192, -0.83311339, -0.84889275, -0.84876322]]])
(10, 10, 4)

display(trainY[0:5])
display(trainY.shape)
---
array([0, 1, 0, 1, 0], dtype=int64)
(47666,)

Upvotes: 0

Views: 996

Answers (2)

Nivesh Gadipudi

Reputation: 506

The essence of neural networks is to introduce non-linearities. You didn't specify an activation function for any of the three convolution layers.

Refer to this for the different activation functions.

tf.keras.layers.Conv2D(64, 3, activation='relu')
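
To illustrate why missing activations matter, here is a minimal sketch in plain Python (no TensorFlow; the matrices W1 and W2 and the helper names are made up for illustration). Two linear layers applied back to back give the same result as one layer whose weights are their product, so stacking them adds no expressive power. This deliberately ignores the MaxPooling2D layers in the question, which are themselves non-linear.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(x * y for x, y in zip(row, v)) for row in m]

def relu(v):
    """Element-wise ReLU."""
    return [max(0.0, x) for x in v]

# Hypothetical weights for two "layers" and one input vector.
W1 = [[1.0, -2.0], [0.5, 1.0]]
W2 = [[2.0, 1.0], [-1.0, 3.0]]
x = [1.0, 2.0]

stacked = matvec(W2, matvec(W1, x))          # layer 2 applied after layer 1, no activation
collapsed = matvec(matmul(W2, W1), x)        # a single layer with weights W2 @ W1
with_relu = matvec(W2, relu(matvec(W1, x)))  # same layers with ReLU in between

print(stacked, collapsed, with_relu)
```

The first two results are identical, while the ReLU version differs, which is the extra capacity the activation provides.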

Upvotes: 0

Mostafa Labib

Reputation: 809

The model isn't learning because the convolution layers use a linear activation: the activation argument is None by default, which means no activation function is applied. The activation usually used with conv layers is ReLU, so simply add activation='relu' to your conv layers.
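
As a side note on reading the logs in the question: the loss stays at about 0.6931, which is ln 2, the binary cross-entropy of a model that outputs 0.5 for every sample. A minimal sketch in plain Python (the function name is illustrative, not a Keras API):

```python
import math

def binary_crossentropy(y_true, y_pred):
    """Mean binary cross-entropy over paired labels and predicted probabilities."""
    losses = [
        -(y * math.log(p) + (1 - y) * math.log(1 - p))
        for y, p in zip(y_true, y_pred)
    ]
    return sum(losses) / len(losses)

labels = [0, 1, 0, 1, 0]          # the first five labels shown in the question
constant = [0.5] * len(labels)    # a model that always predicts 0.5
loss = binary_crossentropy(labels, constant)

print(round(loss, 4), round(math.log(2), 4))
```

A loss pinned at this value together with roughly 50% accuracy suggests the network is emitting a near-constant 0.5, whatever the underlying cause.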

Upvotes: 3
