user6257803

How do I train a neural network in Keras on data stored in HDF5 files?

I have two fairly large PyTables EArrays which contain the inputs and labels for a regression task. The input array is 4-D (55k x 128 x 128 x 3) and the label array is 1-D (55k). I have a neural network architecture specified in Keras which I want to train on this data, but there are two problems.

  1. The input array at least is too large to fit in memory at once.
  2. I only want to train on some random subset of the full data, since I want to take train, test, and validation splits. I select the splits by slicing on random subsets of the indices.

How can I select subsets of the HDF5 arrays (input and output) according to train/test indices and train on the training subsets, without reading them into memory all at once? Is there some way to create a "view" of the on-disk array that can be sliced and that Keras will see as a regular NumPy ndarray?

What I've tried so far is to convert my arrays to Keras HDF5Matrix objects (with e.g. X = keras.utils.io_utils.HDF5Matrix(X)), but when I then slice this to get a training split, the full slice (80% of the full array) gets put into memory, which gives me a MemoryError.

Upvotes: 5

Views: 2652

Answers (1)

bogatron

Reputation: 19179

You can use the fit_generator method of your Keras model. Just write your own generator class/function that pulls random batches of samples from your HDF5 file. That way, you never have to hold all the data in memory at once. Similarly, if your validation data is too large to fit in memory, the validation_data argument to fit_generator also accepts a generator that produces batches from your validation data.

Essentially, you just need to do an np.random.shuffle on an array of indices into your data set, then split the shuffled index array into training, validation, and testing index arrays. The generators you pass to fit_generator then pull batches from your HDF5 file according to sequential batches of indices in the training and validation index arrays.
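A minimal sketch of that idea follows. The helper names split_indices and batch_generator, the 80/10/10 split, and the batch size are illustrative assumptions, not part of Keras. The generator only assumes that the on-disk arrays can be indexed with a list of row indices, as NumPy arrays can; the indices are sorted within each batch because h5py requires increasing order for this kind of selection. Reshuffling at the start of each pass is an optional extra beyond what is described above.

```python
import numpy as np


def split_indices(n_samples, train_frac=0.8, val_frac=0.1, seed=0):
    """Shuffle 0..n_samples-1 and cut it into train/val/test index arrays."""
    rng = np.random.RandomState(seed)
    indices = np.arange(n_samples)
    rng.shuffle(indices)
    n_train = int(train_frac * n_samples)
    n_val = int(val_frac * n_samples)
    return (indices[:n_train],
            indices[n_train:n_train + n_val],
            indices[n_train + n_val:])


def batch_generator(inputs, labels, indices, batch_size=32):
    """Yield (x, y) batches forever, reading only one batch at a time."""
    indices = np.array(indices)  # copy, so the caller's array is left alone
    while True:
        np.random.shuffle(indices)  # optional: new batch order on each pass
        for start in range(0, len(indices), batch_size):
            # Sort within the batch: some HDF5 wrappers need increasing indices.
            batch = np.sort(indices[start:start + batch_size]).tolist()
            # Only these rows are read from the underlying arrays.
            yield np.asarray(inputs[batch]), np.asarray(labels[batch])
```

With Keras 2 argument names, this could be used roughly as model.fit_generator(batch_generator(X, y, train_idx), steps_per_epoch=len(train_idx) // 32, validation_data=batch_generator(X, y, val_idx), validation_steps=len(val_idx) // 32, epochs=10), where X and y are the on-disk arrays. Older Keras releases name these arguments differently, and recent TensorFlow Keras versions accept a generator directly in fit, so check the version you have.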

Upvotes: 5
