shallow_water

Reputation: 131

keras batchnorm has awful test performance

During cross-validation on the training data, adding batchnorm significantly improves performance. But after retraining on the entire training set, the presence of the batchnorm layer completely destroys the model's generalization to a holdout set. This is a little surprising, and I'm wondering whether I'm implementing the test predictions incorrectly.

Generalization w/o the batchnorm layer present is fine (not high enough for my project's goals, but reasonable for such a simple net).

I cannot share my data, but does anyone see an obvious implementation error? Is there a flag that should be set to test mode? I can't find an answer in the docs, and dropout (which also should have different train/test behavior) works as expected. Thanks!

code:

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import roc_curve

# Stop training when validation loss stalls, and checkpoint the best weights to disk.
early_stopping = EarlyStopping(monitor='val_loss', patience=10)
filepath = "L1_batch1_weights.best.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')

init = 'he_normal'  # defined but not passed to any layer below
act = 'relu'
neurons1 = 80
dropout_rate = 0.5

# Single hidden layer with batchnorm before the activation; dropout on the input and the hidden layer.
model = Sequential()
model.add(Dropout(0.2, input_shape=(5000,)))
model.add(Dense(neurons1))
model.add(BatchNormalization())
model.add(Activation(act))
model.add(Dropout(dropout_rate))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train with a batch size of 128.
my_model = model.fit(X_train, y_train, batch_size=128, nb_epoch=150,
                     validation_data=(X_test, y_test),
                     callbacks=[early_stopping, checkpoint])

# Reload the best checkpoint before predicting on the holdout set.
model.load_weights("L1_batch1_weights.best.hdf5")
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print("Created model and loaded weights from file")

# Predict on the entire holdout set as a single batch (2925 samples).
probs = model.predict_proba(X_test, batch_size=2925)
fpr, tpr, thresholds = roc_curve(y_test, probs)

Upvotes: 2

Views: 3233

Answers (1)

shallow_water

Reputation: 131

From the docs: "During training we use per-batch statistics to normalize the data, and during testing we use running averages computed during the training phase."
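To illustrate the difference, here is a minimal NumPy sketch of that idea (not the actual Keras internals; the momentum and epsilon values and all names are assumptions for illustration):

import numpy as np

# Training mode: each batch is normalized with its own mean/variance,
# while running averages are accumulated for later use at test time.
def batchnorm_train(batch, running_mean, running_var, momentum=0.99, eps=1e-3):
    mu, var = batch.mean(axis=0), batch.var(axis=0)
    out = (batch - mu) / np.sqrt(var + eps)
    running_mean = momentum * running_mean + (1 - momentum) * mu
    running_var = momentum * running_var + (1 - momentum) * var
    return out, running_mean, running_var

# Test mode: the stored running averages are used, not the batch's own statistics.
def batchnorm_test(batch, running_mean, running_var, eps=1e-3):
    return (batch - running_mean) / np.sqrt(running_var + eps)

# Toy demo: a batch whose statistics differ from the running averages
# is normalized very differently in the two modes.
rng = np.random.RandomState(0)
batch = rng.normal(loc=5.0, scale=2.0, size=(128, 3))
run_mean, run_var = np.zeros(3), np.ones(3)
train_out, run_mean, run_var = batchnorm_train(batch, run_mean, run_var)
test_out = batchnorm_test(batch, run_mean, run_var)
print(train_out.mean(), test_out.mean())  # ~0 in train mode, far from 0 in test mode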

In my case the training batch size was 128. At test time, I had manually set the batch size to the size of the complete test set (2925).

It makes sense that the statistics computed for one batch size won't necessarily carry over to a batch size that is significantly different.

Changing the test batch size to the training batch size (128) produced more stable results. I played with prediction batch sizes to observe the effects: predictions were stable for any batch size within roughly 3x of the training batch size; beyond that, performance deteriorated.
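One quick way to see this is to sweep the prediction batch size and compare the resulting ROC AUC. A rough sketch (the list of batch sizes is arbitrary, and model, X_test, y_test are the objects from the question's code):

from sklearn.metrics import roc_auc_score

# Compare holdout predictions across several prediction batch sizes.
# With a training batch size of 128, sizes far outside ~3x of it degraded in my case.
for bs in [32, 64, 128, 256, 512, 2925]:
    probs = model.predict_proba(X_test, batch_size=bs)
    print("batch_size=%d  AUC=%.4f" % (bs, roc_auc_score(y_test, probs)))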

There is some discussion of how the test batch size interacts with batchnorm and load_weights() here: https://github.com/fchollet/keras/issues/3423

Upvotes: 1
