Reputation: 255
I am new to TensorFlow and I would like to know if there is any specific order to follow when setting up a dataset with tf.data. For example:
data_files = tf.gfile.Glob("%s%s%s" % ("./data/cifar-100-binary/", self.data_key, ".bin"))
data = tf.data.FixedLengthRecordDataset(data_files, record_bytes=3074)
data = data.map(self.load_transform)
if self.shuffle_key:
    data = data.shuffle(5000)
data = data.batch(self.batch_size).repeat(100)
iterator = data.make_one_shot_iterator()
img, label = iterator.get_next()
# label = tf.one_hot(label, depth=100)
print('img_shape:', img.shape)
In this case I read the data, then shuffle it, and then apply the batch and repeat specifications. With this method my computer's RAM usage increases by about 2%.
Then I tried one more method:
data_files = tf.gfile.Glob("%s%s%s" % ("./data/cifar-100-binary/", self.data_key, ".bin"))
data = tf.data.FixedLengthRecordDataset(data_files, record_bytes=3074)
data = data.map(self.load_transform)
data = data.batch(self.batch_size).repeat(100)
if self.shuffle_key:
    data = data.shuffle(5000)
iterator = data.make_one_shot_iterator()
img, label = iterator.get_next()
# label = tf.one_hot(label, depth=100)
print('img_shape:', img.shape)
So in this case, when I specify the batch size and repeat first and then shuffle, the RAM utilization increases by 40% (I do not know why). It would be great if someone could help me figure that out. Is there a sequence I should always follow to define a dataset in TensorFlow using tf.data?
Upvotes: 3
Views: 451
Reputation: 1273
The memory usage increases because you are shuffling batches instead of single records. data.shuffle(5000) will fill a buffer of 5000 elements, then sample randomly from the buffer to produce the next element. data.batch(self.batch_size) changes the element type from single records to batches of records, so if you call batch before shuffle, the shuffle buffer will contain 5000 * self.batch_size records instead of just 5000.
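A rough back-of-envelope calculation illustrates the difference. The 3074 bytes per record comes from the FixedLengthRecordDataset call in the question; the batch size of 128 is a hypothetical value, since the question does not give one. (The real buffer holds the post-map elements, so actual memory depends on the decoded tensor size, but the ratio between the two orderings is the same.)

```python
record_bytes = 3074   # from record_bytes=3074 in the question
buffer_size = 5000
batch_size = 128      # hypothetical, not given in the question

# shuffle before batch: buffer holds single records
mb_records = buffer_size * record_bytes / 2**20

# batch before shuffle: buffer holds whole batches
mb_batches = buffer_size * batch_size * record_bytes / 2**20

print(f"{mb_records:.0f} MiB vs {mb_batches:.0f} MiB")  # ~15 MiB vs ~1876 MiB
```

That factor of batch_size is exactly the jump in RAM usage you observed.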
The order of calling shuffle and batch will also affect the data itself: batching before shuffling results in all elements of a batch being sequential.

batch before shuffle:
>>> dataset = tf.data.Dataset.range(12)
>>> dataset = dataset.batch(3)
>>> dataset = dataset.shuffle(4)
>>> print([element.numpy() for element in dataset])
[array([ 9, 10, 11]), array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8])]
shuffle before batch:
>>> dataset = tf.data.Dataset.range(12)
>>> dataset = dataset.shuffle(4)
>>> dataset = dataset.batch(3)
>>> print([element.numpy() for element in dataset])
[array([1, 2, 5]), array([4, 7, 8]), array([0, 3, 9]), array([ 6, 10, 11])]
Usually shuffling is done before batching to avoid all elements in the batch being sequential.
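To see why batching first keeps each batch sequential, here is a small pure-Python model of shuffle's sliding buffer (a sketch of the mechanism, not TensorFlow's actual implementation):

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Mimic tf.data shuffle: keep a sliding buffer, emit random picks from it."""
    rng = random.Random(seed)
    buffer, out = [], []
    for item in items:
        buffer.append(item)
        if len(buffer) > buffer_size:
            out.append(buffer.pop(rng.randrange(len(buffer))))
    while buffer:  # drain the buffer once the input is exhausted
        out.append(buffer.pop(rng.randrange(len(buffer))))
    return out

def batch(items, n):
    return [items[i:i + n] for i in range(0, len(items), n)]

records = list(range(12))

# shuffle before batch: records get mixed across batch boundaries
a = batch(buffered_shuffle(records, 4), 3)

# batch before shuffle: batch order changes, but each batch stays sequential
b = buffered_shuffle(batch(records, 3), 4)

print(a)
print(b)
```

In the second case the shuffle only ever reorders whole batches, so every batch it emits is still a run of consecutive records, matching the output shown above.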
Upvotes: 1