Himaprasoon

Reputation: 2659

How to change batch size dynamically in Tensorflow 2.0 Dataset?

In TensorFlow 1.x you could change the batch size dynamically using a placeholder, e.g.:

dataset.batch(batch_size=tf.placeholder(tf.int64, shape=[]))
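For illustration, a minimal TF1.x sketch of that pattern (my reconstruction, not the original linked example) might look like:

import tensorflow as tf  # TF 1.x

# A scalar placeholder, fed at initialization time, controls the batch size.
batch_size = tf.placeholder(tf.int64, shape=[])
ds = tf.data.Dataset.range(100).batch(batch_size)
iterator = ds.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={batch_size: 5})
    print(sess.run(next_batch))  # [0 1 2 3 4]
    # Re-initializing with a new feed restarts the dataset with the new size.
    sess.run(iterator.initializer, feed_dict={batch_size: 10})
    print(sess.run(next_batch))  # [0 1 2 3 4 5 6 7 8 9]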

How do you do it in TensorFlow 2.0?

I have tried the following but it doesn't work.

import numpy as np
import tensorflow as tf


def new_gen_function():
    for i in range(100):
        yield np.ones(2).astype(np.float32)


batch_size = tf.Variable(5, trainable=False, dtype=tf.int64)
train_ds = tf.data.Dataset.from_generator(new_gen_function, output_types=(tf.float32)).batch(
    batch_size=batch_size)

for data in train_ds:
    print(data.shape[0])
    batch_size.assign(10)
    print(batch_size)

Output

5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
...
...

I am training a model with a custom training loop using tf.GradientTape. How can I achieve this?

Upvotes: 11

Views: 9163

Answers (3)

mdaoust

Reputation: 6367

Obviously if you're using .from_generator, you can manually batch things in there, but that doesn't really address your question.
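For example (a rough sketch of that idea, with names of my own choosing), the generator can read the current batch size from a plain Python variable before every yield:

import numpy as np
import tensorflow as tf

batch_size = 5  # mutated from outside the generator

def batched_gen():
    i = 0
    while i < 100:
        n = batch_size  # re-read the current value before every yield
        yield np.ones((n, 2), dtype=np.float32)
        i += n

ds = tf.data.Dataset.from_generator(batched_gen, output_types=tf.float32)

it = iter(ds)
print(next(it).shape)  # (5, 2)
batch_size = 10        # takes effect on the generator's next yield
print(next(it).shape)  # (10, 2)

Note that any prefetching in the pipeline would pull elements ahead of time, so the change may not take effect exactly on the next batch.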

The two easiest ways I can think of are below. The first is to include the batch size as a component of the dataset and then build batches of the requested size:

import tensorflow as tf

batch_sizes = tf.data.Dataset.range(4)
ds = batch_sizes.map(lambda n: tf.random.normal(shape=[n, 3]))

for item in ds:
  print(item.shape)

Output

(0, 3)
(1, 3)
(2, 3)
(3, 3)

Or, building on @P-Gn's solution, if you need a version that runs entirely inside a tf.function, you can pack the examples using a tf.TensorArray:

import tensorflow as tf

class Batcher(tf.Module):
  def __init__(self, ds, batch_size=0):
    self.it = iter(ds)
    self._batch_size = tf.Variable(batch_size)

  @property
  def batch_size(self):
    return self._batch_size

  @batch_size.setter
  def batch_size(self, new_size):
    self._batch_size.assign(new_size)

  @tf.function
  def __call__(self):
    examples = tf.TensorArray(dtype=tf.int64, size=self.batch_size)
    for i in tf.range(self.batch_size):
      examples = examples.write(i, next(self.it))

    return examples.stack()

ds = tf.data.Dataset.range(100)
B = Batcher(ds)

B.batch_size = 5
B().numpy()
array([0, 1, 2, 3, 4])

B.batch_size = 10
B().numpy()
array([ 5,  6,  7,  8,  9, 10, 11, 12, 13, 14])

B.batch_size = 3
B().numpy()
array([15, 16, 17])

You could probably do something in between, using tf.nest, to make this general for datasets with more than a single tensor component.
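As a rough sketch of that idea (the dict-valued dataset here is my own example), tf.nest.map_structure can stack each component of a list of nested examples:

import tensorflow as tf

# A dataset whose elements are dicts with two components.
ds = tf.data.Dataset.range(10).map(lambda i: {"x": tf.fill([3], i), "y": i})
it = iter(ds)

def take_batch(n):
  examples = [next(it) for _ in range(n)]
  # Stack the corresponding component across all examples in the list.
  return tf.nest.map_structure(lambda *parts: tf.stack(parts, axis=0), *examples)

batch = take_batch(4)
print(batch["x"].shape)  # (4, 3)
print(batch["y"].shape)  # (4,)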

Also, depending on your use case, methods like group_by_window and bucket_by_sequence_length could be helpful. These produce multi-size batches; they may be exactly what you're looking for, or their implementation could offer a clue for your problem.
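For instance, a minimal group_by_window sketch (using the tf.data.experimental version of the transformation):

import tensorflow as tf

ds = tf.data.Dataset.range(20)
ds = ds.apply(tf.data.experimental.group_by_window(
    key_func=lambda x: x % 2,  # route even and odd elements to separate groups
    reduce_func=lambda key, window: window.batch(5),  # batch each full window
    window_size=5))

for batch in ds:
  print(batch.numpy())  # e.g. [0 2 4 6 8], then [1 3 5 7 9], ...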

Upvotes: 1

P-Gn

Reputation: 24581

I don't think you can do it the way you used to in TF1.

A work-around could be to build the batch yourself by stacking individual samples:

import tensorflow as tf

ds = tf.data.Dataset.range(10).repeat()
iterator = iter(ds)
for batch_size in range(1, 10):
  batch = tf.stack([next(iterator) for _ in range(batch_size)], axis=0)
  print(batch)

# tf.Tensor([0], shape=(1,), dtype=int64)
# tf.Tensor([1 2], shape=(2,), dtype=int64)
# tf.Tensor([3 4 5], shape=(3,), dtype=int64)
# tf.Tensor([6 7 8 9], shape=(4,), dtype=int64)
# tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
# tf.Tensor([5 6 7 8 9 0], shape=(6,), dtype=int64)
# tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int64)
# tf.Tensor([8 9 0 1 2 3 4 5], shape=(8,), dtype=int64)
# tf.Tensor([6 7 8 9 0 1 2 3 4], shape=(9,), dtype=int64)
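For the custom GradientTape loop in the question, a hypothetical way to wire this in (the model, the loss, and the batch-size schedule are all stand-ins of mine):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD()

# A stand-in dataset of (features, target) pairs.
ds = tf.data.Dataset.range(1000).map(
    lambda i: (tf.random.normal([4]), tf.random.normal([1]))).repeat()
iterator = iter(ds)

for step, batch_size in enumerate([2, 4, 8, 16]):  # any schedule you like
  xs, ys = zip(*[next(iterator) for _ in range(batch_size)])
  x, y = tf.stack(xs, axis=0), tf.stack(ys, axis=0)
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  print(step, batch_size, float(loss))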

Upvotes: 4

AlexisBRENON

Reputation: 3079

From what I know, you have to instantiate a new dataset iterator for your change to take effect. This requires a little tweaking to skip the samples that have already been seen.

Here is my simplest solution:

import tensorflow as tf

def get_dataset(batch_size, num_samples_seen):
    # Rebuild the pipeline, skipping the samples already consumed.
    return tf.data.Dataset.range(
        100
    ).skip(
        num_samples_seen
    ).batch(
        batch_size=batch_size
    )

def main():
    batch_size = 1
    num_samples_seen = 0

    train_ds = get_dataset(batch_size, num_samples_seen)

    ds_iterator = iter(train_ds)
    while True:
        try:
            data = next(ds_iterator)
        except StopIteration:
            print("End of iteration")
            break

        print(data)
        # Double the batch size and re-create the iterator from a fresh
        # dataset that skips the samples consumed so far.
        batch_size *= 2
        num_samples_seen += data.shape[0]
        ds_iterator = iter(get_dataset(batch_size, num_samples_seen))
        print("New batch size:", batch_size)

if __name__ == "__main__":
    main()

As you can see here, you have to instantiate a new dataset (through a call to get_dataset) and update the iterator.

I don't know the performance impact of such a solution. Maybe there is another approach that lets you "just" re-instantiate the batching step instead of the whole dataset.
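For what it's worth, one way to approximate that (my own assumption, not part of the answer above) is to keep a single source dataset and rebuild only the cheap skip/batch tail on top of it:

import tensorflow as tf

source = tf.data.Dataset.range(100)  # stand-in for an expensive pipeline

def rebatch(num_samples_seen, batch_size):
  # Only the skip/batch stages are re-created; `source` is built once.
  return iter(source.skip(num_samples_seen).batch(batch_size))

it = rebatch(0, 4)
print(next(it).numpy())  # [0 1 2 3]
it = rebatch(4, 8)       # resume where we left off, with a new batch size
print(next(it).numpy())  # [ 4  5  6  7  8  9 10 11]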

Upvotes: 1
