Reputation: 21
Let's consider a simple CNN, trained on 1000 images of fixed input_shape = (28,28,1)
Surprisingly, it will happily predict on images of a different shape, such as (28,30,1).
Shouldn't it fail instead of silently predicting? If not, why? (Reproducible code below.)
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
X_train = tf.ones((1000,28,28,1))
y_train = tf.ones((1000,10))
model = models.Sequential()
model.add(layers.Conv2D(8, (4,4), input_shape=(28, 28, 1), activation='relu', padding='same'))
model.add(layers.MaxPool2D(pool_size=(2,2)))
model.add(layers.Conv2D(16, (3,3), activation='relu', padding='same'))
model.add(layers.MaxPool2D(pool_size=(2,2)))
model.add(layers.Flatten())
model.add(layers.Dense(10, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(X_train,y_train)
model.layers[0]._build_input_shape # >>>TensorShape([None, 28, 28, 1])
Once fitted, it can also (surprisingly) predict on images of the wrong shape. Wouldn't it be better for it to fail?
X_wrong = tf.ones((1000,28,30,1))
model.predict(X_wrong).shape # works and returns (1000, 10)
model.layers[0](X_wrong).shape # also works: (1000, 28, 30, 8)
I understand that prediction does work in principle, because the number of weights still matches thanks to max pooling:
x = X_wrong
for layer in model.layers[:6]:  # apply each layer up to the first Dense
    x = layer(x)
    print(x.shape)
>>>(1000, 28, 30, 8)
>>>(1000, 14, 15, 8)
>>>(1000, 14, 15, 16)
>>>(1000, 7, 7, 16)
>>>(1000, 784)
>>>(1000, 10)
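To make the arithmetic explicit, here is a small pure-Python sketch (no TensorFlow needed, and `flatten_size` is just a hypothetical helper) of how the spatial shape propagates: a 'same' convolution keeps height and width, and a 2x2 max pool with the default 'valid' padding floor-divides them by 2. Both 28x28 and 28x30 end up at 7x7x16 = 784 after the second pool, while 28x32 does not:

```python
def flatten_size(h, w, channels=16):
    # 'same' convs keep h and w; each 2x2 pool ('valid' padding) floor-divides by 2
    h, w = h // 2, w // 2   # first MaxPool2D
    h, w = h // 2, w // 2   # second MaxPool2D
    return h * w * channels

print(flatten_size(28, 28))  # 784
print(flatten_size(28, 30))  # 784 -> same size, so the Dense layer accepts it
print(flatten_size(28, 32))  # 896 -> mismatch, the Dense layer raises
```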
Increasing the width a bit more does fail, as expected, because 896 != 784:
X_wrong_fail = tf.ones((1000,28,32,1))  # running the layers one by one as above gives:
>>>(1000, 28, 32, 8)
>>>(1000, 14, 16, 8)
>>>(1000, 14, 16, 16)
>>>(1000, 7, 8, 16)
>>>(1000, 896)
>>> ValueError: Input 0 of layer dense_2 is incompatible with the layer: expected axis -1 of input shape to have value 784 but received input with shape (1000, 896)
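If you want the failure in both cases, one workaround is to check the incoming batch shape yourself before calling predict. The helper below is hypothetical (Keras does not provide this strict check out of the box):

```python
def checked_predict(model, x, expected=(28, 28, 1)):
    # Hypothetical guard: raise on ANY per-sample shape mismatch, instead of
    # letting 'same' padding plus pooling silently absorb small differences.
    got = tuple(x.shape[1:])
    if got != expected:
        raise ValueError(f"expected per-sample shape {expected}, got {got}")
    return model.predict(x)
```

With this wrapper, `checked_predict(model, X_wrong)` raises a ValueError even though `model.predict(X_wrong)` would silently succeed.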
But it would have been better to fail in both cases imo. Don't you agree?
Upvotes: 2
Views: 114
Reputation: 21
I fully agree that there should at least be a warning. Sorry, I didn't fully read your post and had already written up the weight-matching explanation; I'll still post it in case someone wants more insight.
Edit: Did you consider changing the padding to "valid"? I believe that would prevent your problem.
Edit2: Apparently it doesn't.
Edit3: Setting padding="same" on both the max pooling and conv layers resolves your problem.
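For reference, the reason the "same"-padded pooling from Edit3 helps: with padding='same', a 2x2 MaxPool2D ceil-divides odd dimensions instead of floor-dividing them, so a 28x30 input no longer collapses to the same flattened size as 28x28. A quick pure-Python sketch of that arithmetic (a hypothetical helper, assuming 2x2 pools and 'same' convs throughout):

```python
import math

def flatten_size_same_pool(h, w, channels=16):
    # with padding='same', a 2x2 pool ceil-divides odd dims instead of flooring
    h, w = math.ceil(h / 2), math.ceil(w / 2)  # first MaxPool2D
    h, w = math.ceil(h / 2), math.ceil(w / 2)  # second MaxPool2D
    return h * w * channels

print(flatten_size_same_pool(28, 28))  # 784
print(flatten_size_same_pool(28, 30))  # 896 -> now the Dense layer fails, as desired
```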
You encounter this behaviour because of how convolution kernels and max pooling layers work:
1. Convolutional layers don't care about the spatial shape: their weights are the filters, and the number of weights depends only on the kernel size and the number of channels, so the same filters can be slid over any height and width.
2. Max pooling layers don't care about the spatial shape either; they have no weights and just reduce the dimensions.
The Dense layer has 10 neurons, and in your example each neuron has 784 weights. So as long as Flatten outputs 784 values, the network is fine and can process the input. You happened to find that 28x28 and 30x30 produce the same output shape after Flatten; 28x32, however, does not: there are not enough weights on the neurons, and the Dense layer crashes.
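The weight counts can be checked by hand. This little sketch (hypothetical helper names) reproduces the Param # column of the summaries below: kernel weights plus biases for a conv layer, input weights plus biases for a dense layer:

```python
def conv_params(kh, kw, in_ch, filters):
    # one kh x kw x in_ch kernel per filter, plus one bias per filter
    return kh * kw * in_ch * filters + filters

def dense_params(inputs, units):
    # one weight per input per unit, plus one bias per unit
    return inputs * units + units

print(conv_params(4, 4, 1, 8))   # 136
print(conv_params(3, 3, 8, 16))  # 1168
print(dense_params(784, 10))     # 7850 (flatten size for 28x28 or 28x30)
print(dense_params(896, 10))     # 8970 (flatten size for 28x32)
print(dense_params(10, 10))      # 110
```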
As a fellow ML dev, I really recommend print(model.summary()) when in doubt about shapes :)
Model for input (28, 28, 1):
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 28, 28, 8) 136
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 8) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 14, 14, 16) 1168
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 16) 0
_________________________________________________________________
flatten (Flatten) (None, 784) 0 <--- the shape here matters
_________________________________________________________________
dense (Dense) (None, 10) 7850
_________________________________________________________________
dense_1 (Dense) (None, 10) 110
=================================================================
Total params: 9,264
Trainable params: 9,264
Non-trainable params: 0
_________________________________________________________________
Model for input (30, 30, 1):
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 30, 30, 8) 136
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 8) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 15, 15, 16) 1168
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 16) 0
_________________________________________________________________
flatten (Flatten) (None, 784) 0 <--- the shape here matters
_________________________________________________________________
dense (Dense) (None, 10) 7850
_________________________________________________________________
dense_1 (Dense) (None, 10) 110
=================================================================
Total params: 9,264
Trainable params: 9,264
Non-trainable params: 0
_________________________________________________________________
Model for input (28, 32, 1):
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 28, 32, 8) 136
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 16, 8) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 14, 16, 16) 1168
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 8, 16) 0
_________________________________________________________________
flatten (Flatten) (None, 896) 0 <--- different shape now
_________________________________________________________________
dense (Dense) (None, 10) 8970
_________________________________________________________________
dense_1 (Dense) (None, 10) 110
=================================================================
Total params: 10,384
Trainable params: 10,384
Non-trainable params: 0
_________________________________________________________________
Upvotes: 2