Richard Yao

Reputation: 41

Keras with TensorFlow backend: MemoryError in model.fit() with checkpoint callbacks

I'm trying to train an autoencoder, but Keras keeps raising a MemoryError in model.fit(). It only happens when I add validation-related parameters to model.fit(), such as validation_split.

Error:

Traceback (most recent call last):
  File "/root/abnormal-spatiotemporal-ae/start_train.py", line 53, in <module>
    train(dataset=dataset, job_folder=job_folder, logger=logger)
  File "/root/abnormal-spatiotemporal-ae/classifier.py", line 109, in train
    callbacks=[snapshot, earlystop, history_log]
  File "/root/anaconda3/envs/py35/lib/python3.5/site-packages/keras/engine/training.py",

line 990, in fit y, val_y = (slice_arrays(y, 0, split_at), File "/root/anaconda3/envs/py35/lib/python3.5/site-packages/keras/utils/generic_utils.py", line 528, in slice_arrays return [None if x is None else x[start:stop] for x in arrays] File "/root/anaconda3/envs/py35/lib/python3.5/site-packages/keras/utils/generic_utils.py", line 528, in return [None if x is None else x[start:stop] for x in arrays] File "/root/anaconda3/envs/py35/lib/python3.5/site-packages/keras/utils/io_utils.py", line 110, in getitem return self.data[idx] File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper File "/root/anaconda3/envs/py35/lib/python3.5/site-packages/h5py/_hl/dataset.py", line 485, in getitem arr = numpy.ndarray(mshape, new_dtype, order='C') MemoryError

Code:

data = HDF5Matrix(os.path.join(video_root_path, '{0}/{0}_train_t{1}.h5'.format(dataset, time_length)),
                  'data')

snapshot = ModelCheckpoint(os.path.join(job_folder,
           'model_snapshot_e{epoch:03d}_{val_loss:.6f}.h5'))
earlystop = EarlyStopping(patience=10)
history_log = LossHistory(job_folder=job_folder, logger=logger)

logger.info("Initializing training...")

history = model.fit(
    data,
    data,
    batch_size=batch_size,
    epochs=nb_epoch,
    validation_split=0.15,
    shuffle='batch',
    callbacks=[snapshot, earlystop, history_log]
)

The code runs correctly when I remove validation_split=0.15 from model.fit() and snapshot from the callbacks.
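For reference, the traceback matches how Keras 2 implements validation_split: it computes a split index and slices the entire x and y arrays in one go, and slicing an HDF5Matrix makes h5py allocate the whole slice as a single in-memory NumPy array. A simplified sketch of that internal step (not the actual library source):

# Simplified sketch of Keras 2's validation_split handling (not the real source).
# The full-array slices below are what force h5py to build huge in-memory copies.
split_at = int(len(data) * (1.0 - 0.15))            # 85% train / 15% validation
x_train, x_val = data[:split_at], data[split_at:]   # each slice is materialized in RAM
y_train, y_val = data[:split_at], data[split_at:]   # y is the same HDF5Matrix here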

The data variable contains all processed images from the training dataset; its shape is (15200, 8, 224, 224, 1) and its size is 6101401600 elements. The code runs on a machine with 64 GB of RAM and a Tesla P100, so memory space should not be a worry, and my Python is 64-bit.
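For reference, a back-of-the-envelope estimate of what such a slice costs, assuming the HDF5 dataset is stored as float32 (the dtype is an assumption; adjust for the actual one):

import numpy as np

shape = (15200, 8, 224, 224, 1)            # shape of the data array
n_elements = int(np.prod(shape))           # 6,101,401,600 elements
bytes_per_element = 4                      # assuming float32
full_gib = n_elements * bytes_per_element / 2**30
train_gib = 0.85 * full_gib                # the train slice created by validation_split

print("full array:  %.1f GiB" % full_gib)  # ~22.7 GiB
print("train slice: %.1f GiB" % train_gib) # ~19.3 GiB per materialized copy

Since data is passed as both x and y, the split produces several such copies, which together with TensorFlow's own allocations can plausibly exhaust 64 GB.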

Model:

input_tensor = Input(shape=(t, 224, 224, 1))

conv1 = TimeDistributed(Conv2D(128, kernel_size=(11, 11), padding='same', strides=(4, 4), name='conv1'),
                        input_shape=(t, 224, 224, 1))(input_tensor)
conv1 = TimeDistributed(BatchNormalization())(conv1)
conv1 = TimeDistributed(Activation('relu'))(conv1)

conv2 = TimeDistributed(Conv2D(64, kernel_size=(5, 5), padding='same', strides=(2, 2), name='conv2'))(conv1)
conv2 = TimeDistributed(BatchNormalization())(conv2)
conv2 = TimeDistributed(Activation('relu'))(conv2)

convlstm1 = ConvLSTM2D(64, kernel_size=(3, 3), padding='same', return_sequences=True, name='convlstm1')(conv2)
convlstm2 = ConvLSTM2D(32, kernel_size=(3, 3), padding='same', return_sequences=True, name='convlstm2')(convlstm1)
convlstm3 = ConvLSTM2D(64, kernel_size=(3, 3), padding='same', return_sequences=True, name='convlstm3')(convlstm2)

deconv1 = TimeDistributed(Conv2DTranspose(128, kernel_size=(5, 5), padding='same', strides=(2, 2), name='deconv1'))(convlstm3)
deconv1 = TimeDistributed(BatchNormalization())(deconv1)
deconv1 = TimeDistributed(Activation('relu'))(deconv1)

decoded = TimeDistributed(Conv2DTranspose(1, kernel_size=(11, 11), padding='same', strides=(4, 4), name='deconv2'))(deconv1)

Upvotes: 0

Views: 826

Answers (1)

Martin

Reputation: 644

This question faced the same problem. The explanation there was that there were too many data points before the flattening layer, which caused the RAM to overflow; it was solved by adding an additional convolution layer.
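As an illustration of that idea (a minimal sketch with made-up layer sizes, not the model from the question, which has no Flatten layer): one extra strided convolution in front of Flatten sharply cuts how many values the flattened representation, and the dense layer after it, have to hold.

from keras.layers import Input, Conv2D, Flatten, Dense
from keras.models import Model

inp = Input(shape=(224, 224, 1))
x = Conv2D(64, (5, 5), strides=(2, 2), padding='same', activation='relu')(inp)  # 112x112x64
x = Conv2D(32, (3, 3), strides=(2, 2), padding='same', activation='relu')(x)    # 56x56x32
# Extra downsampling conv: without it, Flatten sees 56*56*32 = 100352 values per
# sample; with it, only 28*28*32 = 25088, so the dense part needs far less memory.
x = Conv2D(32, (3, 3), strides=(2, 2), padding='same', activation='relu')(x)    # 28x28x32
x = Flatten()(x)
out = Dense(10, activation='softmax')(x)
model = Model(inp, out)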

Upvotes: 1
