user62039

Reputation: 371

How to limit RAM usage while batch training in tensorflow?

I am training a deep neural network on a large image dataset in mini-batches of size 40. My dataset is in .mat format (which I can easily convert to another format, e.g. .npy, if necessary) and is loaded as a 4-D numpy array before training. My problem is that during training the CPU RAM (not GPU RAM) is exhausted very quickly, and the process starts using almost half of my swap memory.

My training code has the following pattern:

import h5py
import numpy as np

batch_size = 40
...
with h5py.File('traindata.mat', 'r') as _data:
    # np.array(...) copies the entire dataset into RAM at once
    train_imgs = np.array(_data['train_imgs'])

# I can replace above with below loading, if necessary
# train_imgs = np.load('traindata.npy')

...

shape_4d = train_imgs.shape
for epoch_i in range(max_epochs):
    for batch_i in range(shape_4d[0] // batch_size):
        y_ = train_imgs[batch_i*batch_size:(batch_i+1)*batch_size]
        ...
        ...

It seems that the initial loading of the full training data is itself the bottleneck (it takes over 12 GB of CPU RAM before I abort).

What is the most efficient way to tackle this bottleneck?

Thanks in advance.

Upvotes: 1

Views: 5667

Answers (1)

jorgemf

Reputation: 1143

Loading a big dataset into memory all at once is not a good idea. I suggest loading it incrementally instead; take a look at the Dataset API in TensorFlow: https://www.tensorflow.org/programmers_guide/datasets

You might need to convert your data into another format, but if you have a CSV or TXT file with one example per line, you can use TextLineDataset and feed the model from it:

import tensorflow as tf

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)

def _parse_py_fun(text_line):
    ...  # your custom code here; it must return NumPy arrays

def _map_fun(text_line):
    # wrap the Python parser so it can run inside the TensorFlow graph
    result = tf.py_func(_parse_py_fun, [text_line], [tf.uint8])
    ...  # other TensorFlow code here
    return result

dataset = dataset.map(_map_fun)
dataset = dataset.batch(4)
iterator = dataset.make_one_shot_iterator()
input_data_of_your_model = iterator.get_next()

output = build_model_fn(input_data_of_your_model)

sess.run([output]) # the input was assigned directly when creating the model
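
If converting to text files is not practical, the same idea (read only one batch at a time from disk) can be sketched with plain NumPy. This is a minimal sketch, assuming the data has been saved in .npy format as the question says is possible; the helper name iter_batches and the small demo file are hypothetical and only there for illustration. np.load with mmap_mode='r' memory-maps the file instead of reading it, so only the slices you index are paged into RAM. An h5py dataset can be sliced in a similar way, as long as it is not wrapped in np.array(...) first.

```python
import os
import tempfile

import numpy as np


def iter_batches(npy_path, batch_size):
    """Yield consecutive mini-batches without loading the whole file into RAM."""
    # mmap_mode='r' maps the file read-only; data is read from disk only when sliced.
    data = np.load(npy_path, mmap_mode='r')
    # Integer division drops the last partial batch, as in the question's loop.
    n_batches = data.shape[0] // batch_size
    for batch_i in range(n_batches):
        start = batch_i * batch_size
        # np.array(...) copies just this slice into an ordinary in-memory array.
        yield np.array(data[start:start + batch_size])


# Hypothetical stand-in for traindata.npy: 100 tiny "images" of shape 8x8x3.
demo_path = os.path.join(tempfile.mkdtemp(), 'demo_traindata.npy')
demo_imgs = np.arange(100 * 8 * 8 * 3, dtype=np.float32).reshape(100, 8, 8, 3)
np.save(demo_path, demo_imgs)

# Each y_ is one mini-batch, ready to be fed to the training step.
batches = [y_ for y_ in iter_batches(demo_path, batch_size=40)]
```

In the real training loop you would iterate over iter_batches('traindata.npy', 40) once per epoch instead of collecting the batches in a list, so that only one batch is held in memory at a time.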

Upvotes: 3
