Karnivaurus

Reputation: 24121

TensorFlow takes too long to load data into a tf.Dataset

I am using TensorFlow 1.9 to train on an image dataset that is too big to load from my hard drive into RAM. I have therefore split the dataset into two halves on my hard drive, and I want to know the most efficient way to train on the entire dataset.

My GPU has 3 GB of memory, and I have 32 GB of RAM. Each half of the dataset is 20 GB. My hard drive has plenty of free space (over 1 TB).

My attempt is as follows. I create an initializable tf.Dataset, and then on every epoch, I initialize it twice: once for each of the halves of the dataset. In this way, each epoch sees the entire dataset, but only has to have half of it loaded in RAM at any one time.

However, this is very slow, because it takes a long time to load the data from my hard drive, and also quite a long time to initialize the dataset with this data each time.

Is there a more efficient way to do this?

I have tried training on each half of the dataset for multiple epochs before loading the other half of the dataset, which is much faster, but this gives much worse performance on the validation data. Presumably, this is because the model is overfitting on each half and then not generalising to the data in the other half.

In my code below, I create and save some test data, which is then loaded as described above. Loading each half of the dataset takes about 5 seconds, and initializing the dataset with this data takes about 1 second. These may seem like small amounts, but they add up over multiple epochs. In fact, my computer spends almost as much time loading the data as it does actually training on it.

import tensorflow as tf
import numpy as np
import time

# Create and save 2 datasets of test NumPy data
dataset_num_elements = 100000
element_dim = 10000
batch_size = 50
test_data = np.zeros([2, int(dataset_num_elements * 0.5), element_dim], dtype=np.float32)
np.savez('test_data_1.npz', x=test_data[0])
np.savez('test_data_2.npz', x=test_data[1])

# Create the TensorFlow dataset
data_placeholder = tf.placeholder(tf.float32, [int(dataset_num_elements * 0.5), element_dim])
dataset = tf.data.Dataset.from_tensor_slices(data_placeholder)
dataset = dataset.shuffle(buffer_size=dataset_num_elements)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(1)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
init_op = iterator.initializer

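# Each half holds dataset_num_elements / 2 elements, so one pass over a half is this many batches
num_batches = int(dataset_num_elements * 0.5 / batch_size)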

with tf.Session() as sess:
    while True:
        for dataset_section in range(2):
            # Load the data from the hard drive
            t1 = time.time()
            print('Loading')
            loaded_data = np.load('test_data_' + str(dataset_section + 1) + '.npz')
            x = loaded_data['x']
            print('Loaded')
            t2 = time.time()
            loading_time = t2 - t1
            print('Loading time = ' + str(loading_time))
            # Initialize the dataset with this loaded data
            t1 = time.time()
            sess.run(init_op, feed_dict={data_placeholder: x})
            t2 = time.time()
            initialization_time = t2 - t1
            print('Initialization time = ' + str(initialization_time))
            # Read the data in batches
            for i in range(num_batches):
                x = sess.run(next_element)

Upvotes: 2

Views: 3089

Answers (1)

Kaihong Zhang

Reputation: 419

Feeding data through feed_dict is not an efficient way to input data. You can build the input pipeline like this instead:

  1. Create a dataset of file names covering all the input files. You can shuffle and repeat the dataset at this stage.
  2. Map this dataset to the data itself: the map function reads, decodes, and transforms each image. Use multiple threads for the map step.
  3. Prefetch the data for training.
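
The three steps above can be sketched roughly as follows. This is only a minimal illustration, not a drop-in replacement for the question's code: it assumes the images are stored as individual PNG files that all have the same size (so they can be batched without resizing), and the names `load_image`, `build_dataset`, and `num_threads` are made up for this example.

```python
import tensorflow as tf

# tf.read_file is the TF 1.x name (the question uses TF 1.9); later
# releases expose the same op as tf.io.read_file.
_read_file = tf.read_file if hasattr(tf, "read_file") else tf.io.read_file


def load_image(path):
    """Read one PNG file and return it as a float32 tensor scaled to [0, 1]."""
    raw = _read_file(path)
    image = tf.image.decode_png(raw, channels=3)
    return tf.image.convert_image_dtype(image, tf.float32)


def build_dataset(filenames, batch_size=50, num_threads=4):
    """Build a shuffled, repeating, batched pipeline over image files.

    Assumes every file decodes to the same height and width.
    """
    # Step 1: a dataset of file names only, so shuffling the whole
    # dataset costs very little memory.
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.shuffle(buffer_size=len(filenames))
    dataset = dataset.repeat()
    # Step 2: read and decode several files in parallel.
    dataset = dataset.map(load_image, num_parallel_calls=num_threads)
    dataset = dataset.batch(batch_size)
    # Step 3: keep one batch ready while the previous one is in use.
    dataset = dataset.prefetch(1)
    return dataset
```

In TF 1.x, batches would then be pulled with `dataset.make_one_shot_iterator().get_next()` inside a session, with no feed_dict involved. Because only file names are held in the shuffle buffer, the images themselves never need to fit in RAM at once.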

This is just one example. You can design your own pipeline, but keep the following in mind:

  • Keep whatever you feed as lightweight as possible.
  • Use multiple threads to read and preprocess the data.
  • Prefetch data for training.
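
For a hand-rolled pipeline, the last two points can be illustrated without TensorFlow at all. The sketch below uses only the Python standard library and is not how tf.data is implemented; `prefetching_loader` and its parameters are hypothetical names chosen for this example. Worker threads load items in the background, and at most `prefetch` loads are kept in flight so memory use stays bounded.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor


def prefetching_loader(items, load_fn, num_threads=4, prefetch=8):
    """Yield load_fn(item) for each item, in order.

    Up to `prefetch` loads are submitted ahead of the consumer, so the
    worker threads keep reading while the caller processes earlier results.
    """
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        pending = deque()
        for item in items:
            pending.append(pool.submit(load_fn, item))
            # Once the look-ahead window is full, hand over the oldest result.
            if len(pending) >= prefetch:
                yield pending.popleft().result()
        # Drain whatever is still in flight after the input is exhausted.
        while pending:
            yield pending.popleft().result()
```

Here `items` could be a shuffled list of file paths and `load_fn` a function that reads and preprocesses one file; the training loop then simply iterates over the generator.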

Upvotes: 5
