Reputation: 225
I'm training a simple adversarial image to break a pretrained model. However, the result I obtained during the fit() process is different from calling predict() on the same input (constant input).
model.trainable = False
gan = Sequential()
gan.add(Dense( 256 * 256 * 3, use_bias=False, input_shape=(1,)))
gan.add(Reshape((256, 256, 3)))
gan.add(model)
gan.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_2 (Dense) (None, 196608) 196608
_________________________________________________________________
reshape_2 (Reshape) (None, 256, 256, 3) 0
_________________________________________________________________
sequential_1 (Sequential) (None, 2) 24952610
=================================================================
Total params: 25,149,218
Trainable params: 196,608
Non-trainable params: 24,952,610
_________________________________________________________________
img = img.reshape(256, 256, 3)
def custom_loss(layer):
# Create a loss function that adds the MSE loss to the mean of all squared activations of a specific layer
def loss(y_true,y_pred):
y_true = K.print_tensor(y_true, message='y_true = ')
y_pred = K.print_tensor(y_pred, message='y_pred = ')
label_diff = K.square(y_pred - y_true)
return K.mean(label_diff)
# Return a function
return loss
gan.compile(optimizer='adam',
loss=custom_loss(gan.layers[1]), # Call the loss function with the selected layer
metrics=['accuracy'])
x = np.ones((1,1))
goal = np.array([0, 1])
y = goal.reshape((1,2))
gan.fit(x, y, epochs=300, verbose=1)
During fit(), the loss is decreasing nicely
Epoch 1/300
1/1 [==============================] - 5s 5s/step - loss: 0.9950 - acc: 0.0000e+00
...
Epoch 300/300
1/1 [==============================] - 0s 46ms/step - loss: 0.0045 - acc: 1.0000
In the backend, the y_pred and y_true were also correct
......
y_true = [[0 1]]
y_pred = [[0.100334756 0.899665236]]
y_true = [[0 1]]
y_pred = [[0.116679631 0.883320332]]
y_true = [[0 1]]
y_pred = [[0.0832592845 0.916740656]]
y_true = [[0 1]]
y_pred = [[0.098835744 0.901164234]]
y_true = [[0 1]]
y_pred = [[0.0979194269 0.902080595]]
y_true = [[0 1]]
y_pred = [[0.057831794 0.942168236]]
y_true = [[0 1]]y_pred = [[0.0760448873 0.923955142]]
y_true = [[0 1]]
y_pred = [[0.041532293 0.958467722]]
y_true = [[0 1]]
y_pred = [[0.0667938739 0.933206141]]
print(gan.predict(x))
Gives
[[0.99923825 0.00076174]]
Tried with both pretrained Resnet and InceptionV3 and both are experiencing the same problem. Attached is model.summary()
For Inception:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
inception_v3 (Model) (None, None, None, 2048) 21802784
_________________________________________________________________
global_average_pooling2d_1 ( (None, 2048) 0
_________________________________________________________________
dense_1 (Dense) (None, 1024) 2098176
_________________________________________________________________
dropout_1 (Dropout) (None, 1024) 0
_________________________________________________________________
dense_2 (Dense) (None, 1024) 1049600
_________________________________________________________________
dropout_2 (Dropout) (None, 1024) 0
_________________________________________________________________
dense_3 (Dense) (None, 2) 2050
=================================================================
Total params: 24,952,610
Trainable params: 14,264,706
Non-trainable params: 10,687,904
_________________________________________________________________
For Resnet:
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 262, 262, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 128, 128, 64) 9472 conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, 128, 128, 64) 256 conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 128, 128, 64) 0 bn_conv1[0][0]
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D) (None, 130, 130, 64) 0 activation_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 64, 64, 64) 0 pool1_pad[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 64, 64, 64) 4160 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 64, 64, 64) 256 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 64, 64, 64) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 64, 64, 64) 36928 activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 64, 64, 64) 256 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 64, 64, 64) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 64, 64, 256) 16640 activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 64, 64, 256) 16640 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 64, 64, 256) 1024 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 64, 64, 256) 1024 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 64, 64, 256) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 64, 64, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D) (None, 64, 64, 64) 16448 activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 64, 64, 64) 256 res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 64, 64, 64) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 64, 64, 64) 36928 activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 64, 64, 64) 256 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 64, 64, 64) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 64, 64, 256) 16640 activation_6[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 64, 64, 256) 1024 res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 64, 64, 256) 0 bn2b_branch2c[0][0]
activation_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 64, 64, 256) 0 add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D) (None, 64, 64, 64) 16448 activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 64, 64, 64) 256 res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 64, 64, 64) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 64, 64, 64) 36928 activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 64, 64, 64) 256 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 64, 64, 64) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 64, 64, 256) 16640 activation_9[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 64, 64, 256) 1024 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 64, 64, 256) 0 bn2c_branch2c[0][0]
activation_7[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 64, 64, 256) 0 add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 32, 32, 128) 32896 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 32, 32, 128) 512 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 32, 32, 128) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 32, 32, 128) 147584 activation_11[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 32, 32, 128) 512 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 32, 32, 128) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 32, 32, 512) 66048 activation_12[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 32, 32, 512) 131584 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 32, 32, 512) 2048 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 32, 32, 512) 2048 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 32, 32, 512) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 32, 32, 512) 0 add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D) (None, 32, 32, 128) 65664 activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 32, 32, 128) 512 res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 32, 32, 128) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 32, 32, 128) 147584 activation_14[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 32, 32, 128) 512 res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_15 (Activation) (None, 32, 32, 128) 0 bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D) (None, 32, 32, 512) 66048 activation_15[0][0]
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 32, 32, 512) 2048 res3b_branch2c[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 32, 32, 512) 0 bn3b_branch2c[0][0]
activation_13[0][0]
__________________________________________________________________________________________________
activation_16 (Activation) (None, 32, 32, 512) 0 add_5[0][0]
__________________________________________________________________________________________________
res3c_branch2a (Conv2D) (None, 32, 32, 128) 65664 activation_16[0][0]
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 32, 32, 128) 512 res3c_branch2a[0][0]
__________________________________________________________________________________________________
activation_17 (Activation) (None, 32, 32, 128) 0 bn3c_branch2a[0][0]
__________________________________________________________________________________________________
res3c_branch2b (Conv2D) (None, 32, 32, 128) 147584 activation_17[0][0]
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 32, 32, 128) 512 res3c_branch2b[0][0]
__________________________________________________________________________________________________
activation_18 (Activation) (None, 32, 32, 128) 0 bn3c_branch2b[0][0]
__________________________________________________________________________________________________
res3c_branch2c (Conv2D) (None, 32, 32, 512) 66048 activation_18[0][0]
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 32, 32, 512) 2048 res3c_branch2c[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 32, 32, 512) 0 bn3c_branch2c[0][0]
activation_16[0][0]
__________________________________________________________________________________________________
activation_19 (Activation) (None, 32, 32, 512) 0 add_6[0][0]
__________________________________________________________________________________________________
res3d_branch2a (Conv2D) (None, 32, 32, 128) 65664 activation_19[0][0]
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 32, 32, 128) 512 res3d_branch2a[0][0]
__________________________________________________________________________________________________
activation_20 (Activation) (None, 32, 32, 128) 0 bn3d_branch2a[0][0]
__________________________________________________________________________________________________
res3d_branch2b (Conv2D) (None, 32, 32, 128) 147584 activation_20[0][0]
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 32, 32, 128) 512 res3d_branch2b[0][0]
__________________________________________________________________________________________________
activation_21 (Activation) (None, 32, 32, 128) 0 bn3d_branch2b[0][0]
__________________________________________________________________________________________________
res3d_branch2c (Conv2D) (None, 32, 32, 512) 66048 activation_21[0][0]
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 32, 32, 512) 2048 res3d_branch2c[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 32, 32, 512) 0 bn3d_branch2c[0][0]
activation_19[0][0]
__________________________________________________________________________________________________
activation_22 (Activation) (None, 32, 32, 512) 0 add_7[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D) (None, 16, 16, 256) 131328 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 16, 16, 256) 1024 res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_23 (Activation) (None, 16, 16, 256) 0 bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D) (None, 16, 16, 256) 590080 activation_23[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 16, 16, 256) 1024 res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_24 (Activation) (None, 16, 16, 256) 0 bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D) (None, 16, 16, 1024) 263168 activation_24[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D) (None, 16, 16, 1024) 525312 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 16, 16, 1024) 4096 res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 16, 16, 1024) 4096 res4a_branch1[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 16, 16, 1024) 0 bn4a_branch2c[0][0]
bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_25 (Activation) (None, 16, 16, 1024) 0 add_8[0][0]
__________________________________________________________________________________________________
res4b_branch2a (Conv2D) (None, 16, 16, 256) 262400 activation_25[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 16, 16, 256) 1024 res4b_branch2a[0][0]
__________________________________________________________________________________________________
activation_26 (Activation) (None, 16, 16, 256) 0 bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D) (None, 16, 16, 256) 590080 activation_26[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 16, 16, 256) 1024 res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_27 (Activation) (None, 16, 16, 256) 0 bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D) (None, 16, 16, 1024) 263168 activation_27[0][0]
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 16, 16, 1024) 4096 res4b_branch2c[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 16, 16, 1024) 0 bn4b_branch2c[0][0]
activation_25[0][0]
__________________________________________________________________________________________________
activation_28 (Activation) (None, 16, 16, 1024) 0 add_9[0][0]
__________________________________________________________________________________________________
res4c_branch2a (Conv2D) (None, 16, 16, 256) 262400 activation_28[0][0]
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 16, 16, 256) 1024 res4c_branch2a[0][0]
__________________________________________________________________________________________________
activation_29 (Activation) (None, 16, 16, 256) 0 bn4c_branch2a[0][0]
__________________________________________________________________________________________________
res4c_branch2b (Conv2D) (None, 16, 16, 256) 590080 activation_29[0][0]
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 16, 16, 256) 1024 res4c_branch2b[0][0]
__________________________________________________________________________________________________
activation_30 (Activation) (None, 16, 16, 256) 0 bn4c_branch2b[0][0]
__________________________________________________________________________________________________
res4c_branch2c (Conv2D) (None, 16, 16, 1024) 263168 activation_30[0][0]
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 16, 16, 1024) 4096 res4c_branch2c[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 16, 16, 1024) 0 bn4c_branch2c[0][0]
activation_28[0][0]
__________________________________________________________________________________________________
activation_31 (Activation) (None, 16, 16, 1024) 0 add_10[0][0]
__________________________________________________________________________________________________
res4d_branch2a (Conv2D) (None, 16, 16, 256) 262400 activation_31[0][0]
__________________________________________________________________________________________________
... omitted ...
Total params: 23,593,859
Trainable params: 23,540,739
Non-trainable params: 53,120
__________________________________________________________________________________________________
Upvotes: 0
Views: 820
Reputation: 86600
Those pretrained models contain BatchNormalization
layers.
It's expected that they perform differently between train and test (this is also true for Dropout
layers, but the differences would not be so drastic).
A BatchNormalization
during training will use the mean and variance from the current batch to do normalization, it will also apply some statistic compensation for the fact that one batch may not be representative of the full dataset.
But during evaluation, the BatchNormalization
will use the adjusted values that were gathered during training for mean and variation. (In this case, gathered during "pretraining" not your training)
For BatchNormalization
to work correctly, you need that the inputs to the pretrained model are in the same range as the model's original training data. Otherwise you have to leave the BatchNormalization
layers trainable so the mean and variance adjust for your data.
But your training needs significant batch sizes as well as real data to train properly.
Hints for training images.
In the same module where you import the pretrained model, you can import the preprocess_input
function. Give it some images loaded with keras.preprocessing.images.load_img
and see what is the model's expected range.
When using ImageDataGenerator
, you can pass this preprocess_input
function so the generator gives you expected data.
Upvotes: 4