saru

Reputation: 255

tf.data pipeline design for optimized performance

I am new to TensorFlow and I would like to know if there is a specific order in which to set up a dataset using tf.data. For example:

    data_files = tf.gfile.Glob("%s%s%s" % ("./data/cifar-100-binary/", self.data_key, ".bin"))
    data = tf.data.FixedLengthRecordDataset(data_files, record_bytes=3074)  # 2 label bytes + 32*32*3 image bytes per CIFAR-100 record
    data = data.map(self.load_transform)
    if self.shuffle_key:
        data = data.shuffle(5000)

    data = data.batch(self.batch_size).repeat(100)
    iterator = data.make_one_shot_iterator()
    img, label = iterator.get_next()
    # label = tf.one_hot(label, depth=100)
    print('img_shape:', img.shape)

In this case I read the data, then shuffle it, and then apply the batch and repeat specifications. With this method my computer's RAM usage increases by about 2%.

Then I tried another ordering:

    data_files = tf.gfile.Glob("%s%s%s" % ("./data/cifar-100-binary/", self.data_key, ".bin"))
    data = tf.data.FixedLengthRecordDataset(data_files, record_bytes=3074)
    data = data.map(self.load_transform)
    data = data.batch(self.batch_size).repeat(100)
    if self.shuffle_key:
        data = data.shuffle(5000)
    iterator = data.make_one_shot_iterator()
    img, label = iterator.get_next()
    # label = tf.one_hot(label, depth=100)
    print('img_shape:', img.shape)

In this case, when I specify the batch size and repeat first and then shuffle, RAM utilization increases by 40% (I do not know why). It would be great if someone could help me figure that out. Is there an order I should always follow when defining a dataset in TensorFlow using tf.data?

Upvotes: 3

Views: 451

Answers (1)

AAudibert

Reputation: 1273

The memory usage increases because you are shuffling batches instead of single records.

data.shuffle(5000) will fill a buffer of 5000 elements, then sample randomly from the buffer to produce the next element.

data.batch(self.batch_size) changes the element type from single records to batches of records. So if you call batch before shuffle, the shuffle buffer will hold 5000 batches, i.e. 5000 * self.batch_size records, instead of just 5000 records.
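To put rough numbers on this (using a hypothetical batch size of 128, since the question does not state one), each CIFAR-100 binary record is 3074 bytes, so the two orderings give very different buffer footprints:

    # Back-of-the-envelope shuffle-buffer memory, counting raw record bytes only.
    # batch_size = 128 is a hypothetical value, not taken from the question.
    record_bytes = 3074   # 2 label bytes + 32*32*3 image bytes
    buffer_size = 5000
    batch_size = 128

    # shuffle before batch: buffer holds single records
    print(buffer_size * record_bytes / 1e6, "MB")               # ~15 MB
    # batch before shuffle: buffer holds whole batches
    print(buffer_size * batch_size * record_bytes / 1e9, "GB")  # ~2 GB

And if load_transform decodes the image bytes to float32, each element is roughly four times larger still, so a multi-gigabyte buffer (and the RAM jump you observed) is exactly what you would expect from shuffling after batching.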

The order of calling shuffle and batch will also affect the data itself. Batching before shuffling will result in all elements of a batch being sequential.

Batch before shuffle (these examples use TF 2.x eager execution):

    >>> dataset = tf.data.Dataset.range(12)
    >>> dataset = dataset.batch(3)
    >>> dataset = dataset.shuffle(4)
    >>> print([element.numpy() for element in dataset])
    [array([ 9, 10, 11]), array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8])]

Shuffle before batch:

    >>> dataset = tf.data.Dataset.range(12)
    >>> dataset = dataset.shuffle(4)
    >>> dataset = dataset.batch(3)
    >>> print([element.numpy() for element in dataset])
    [array([1, 2, 5]), array([4, 7, 8]), array([0, 3, 9]), array([ 6, 10, 11])]

Usually shuffling is done before batching, both to avoid the elements within a batch being sequential and to keep the shuffle buffer small.
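Applied to the pipeline from the question, a minimal sketch of the usual ordering is essentially the asker's first version, with an optional prefetch added at the end to overlap input preparation with training (load_transform, data_key, and shuffle_key are the question's own names):

    data_files = tf.gfile.Glob("%s%s%s" % ("./data/cifar-100-binary/", self.data_key, ".bin"))
    data = tf.data.FixedLengthRecordDataset(data_files, record_bytes=3074)
    data = data.map(self.load_transform)
    if self.shuffle_key:
        data = data.shuffle(5000)        # buffer of 5000 single records, not batches
    data = data.batch(self.batch_size)   # batch after shuffling
    data = data.repeat(100)
    data = data.prefetch(1)              # prepare the next batch while the current one trains
    iterator = data.make_one_shot_iterator()
    img, label = iterator.get_next()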

Upvotes: 1
