Reputation: 2659
In TensorFlow 1.x you could change the batch size dynamically using a placeholder, e.g.
dataset.batch(batch_size=tf.placeholder())
See full example
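For context, the TF 1.x pattern looked roughly like the sketch below (written from memory, not the linked example): feed the batch size through a placeholder and re-initialize the iterator whenever a new value is wanted.
import tensorflow as tf  # TF 1.x only

# Rough sketch of the TF 1.x pattern: the placeholder is fed at
# iterator-initialization time, so the batch size can change per run.
batch_size_ph = tf.placeholder(tf.int64, shape=[])
ds = tf.data.Dataset.range(100).batch(batch_size_ph)
iterator = ds.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={batch_size_ph: 5})
    print(sess.run(next_batch))    # first batch of 5
    sess.run(iterator.initializer, feed_dict={batch_size_ph: 10})
    print(sess.run(next_batch))    # first batch of 10 (iterator restarts)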
How do you do it in TensorFlow 2.0?
I have tried the following, but it doesn't work:
import numpy as np
import tensorflow as tf

def new_gen_function():
    for i in range(100):
        yield np.ones(2).astype(np.float32)

batch_size = tf.Variable(5, trainable=False, dtype=tf.int64)
train_ds = tf.data.Dataset.from_generator(new_gen_function, output_types=(tf.float32)).batch(
    batch_size=batch_size)

for data in train_ds:
    print(data.shape[0])
    batch_size.assign(10)
    print(batch_size)
Output
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
...
...
I am training a model with a custom training loop using tf.GradientTape. How can I achieve this?
Upvotes: 11
Views: 9163
Reputation: 6367
Obviously if you're using .from_generator, you can manually batch things in there, but that doesn't really address your question.
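Just to illustrate that point, a minimal sketch of batching inside the generator itself (the list of batch sizes here is a made-up placeholder for whatever schedule you want):
import numpy as np
import tensorflow as tf

step_batch_sizes = [5, 10, 3]   # hypothetical schedule; substitute your own logic

def batched_gen():
    for n in step_batch_sizes:
        yield np.ones((n, 2), dtype=np.float32)   # emit an already-built batch

ds = tf.data.Dataset.from_generator(batched_gen, output_types=tf.float32)
for batch in ds:
    print(batch.shape)   # (5, 2), (10, 2), (3, 2)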
Of the two easiest approaches I can think of, the first is to include the batch size as a component of the dataset and then build batches of the requested size:
import tensorflow as tf

batch_sizes = tf.data.Dataset.range(4)
ds = batch_sizes.map(lambda n: tf.random.normal(shape=[n, 3]))

for item in ds:
    print(item.shape)
    print()
(0,3)
(1,3)
(2,3)
(3,3)
Or, building on @PG-N's solution, if you need a version that runs totally inside a tf.function, you can pack them using tf.TensorArray:
import tensorflow as tf

class Batcher(tf.Module):
    def __init__(self, ds, batch_size=0):
        self.it = iter(ds)
        self._batch_size = tf.Variable(batch_size)

    @property
    def batch_size(self):
        return self._batch_size

    @batch_size.setter
    def batch_size(self, new_size):
        self._batch_size.assign(new_size)

    @tf.function
    def __call__(self):
        examples = tf.TensorArray(dtype=tf.int64, size=self.batch_size)
        for i in tf.range(self.batch_size):
            examples = examples.write(i, next(self.it))
        return examples.stack()
ds = tf.data.Dataset.range(100)
B = Batcher(ds)
B.batch_size = 5
B().numpy()
array([0, 1, 2, 3, 4])
B.batch_size = 10
B().numpy()
array([ 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
B.batch_size = 3
B().numpy()
array([ 15, 16, 17])
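Since you mention a custom training loop with tf.GradientTape, here is a minimal sketch of how such a batcher could plug into one. The model, optimizer, and toy targets below are made-up placeholders, not part of your setup:
import tensorflow as tf

# Minimal sketch, assuming the Batcher class above is in scope.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(0.01)

batcher = Batcher(tf.data.Dataset.range(10_000))
batcher.batch_size = 8

for step in range(4):
    x = tf.cast(batcher(), tf.float32)[:, None]   # shape (batch_size, 1)
    y = 2.0 * x                                   # toy target
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    if step == 1:
        batcher.batch_size = 16                   # grow the batch mid-training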
You could probably do something in the middle using tf.nest to make this general to datasets with more than just a single tensor component.
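For instance, a rough sketch (my own illustration, assuming a dataset of (feature, label) pairs) that stacks each component with tf.nest.map_structure:
import tensorflow as tf

# Stack each component of (feature, label) pairs separately with
# tf.nest.map_structure, so the batch size can vary per call.
ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([20, 3]), tf.range(20))).repeat()
it = iter(ds)

def take_batch(iterator, batch_size):
    samples = [next(iterator) for _ in range(batch_size)]
    return tf.nest.map_structure(lambda *parts: tf.stack(parts, axis=0), *samples)

x, y = take_batch(it, 4)
print(x.shape, y.shape)   # (4, 3) (4,)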
Also, depending on your use case, methods like group_by_window and bucket_by_sequence_length could be helpful. These produce batches of several sizes, so they might be exactly what you're looking for, or their implementation might give you a clue for your problem.
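For example, a hedged sketch of bucket_by_sequence_length (available as a Dataset method in newer TF releases; older versions expose it as tf.data.experimental.bucket_by_sequence_length used via ds.apply, and output_signature below needs TF 2.4+):
import tensorflow as tf

# Variable-length sequences get routed into buckets, each with its own batch size.
def gen():
    for length in [3, 5, 9, 4, 12, 7]:
        yield tf.ones([length], dtype=tf.float32)

ds = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32))

batched = ds.bucket_by_sequence_length(
    element_length_func=lambda x: tf.shape(x)[0],
    bucket_boundaries=[5, 10],       # buckets: <5, 5..9, >=10
    bucket_batch_sizes=[4, 2, 1])    # one batch size per bucket

for batch in batched:
    print(batch.shape)               # e.g. (2, 9), (1, 12), (2, 4), ... padded per bucket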
Upvotes: 1
Reputation: 24581
I don't think you can do it the way you used to in TF1.
A work-around could be to build the batch yourself by stacking individual samples:
import tensorflow as tf

ds = tf.data.Dataset.range(10).repeat()
iterator = iter(ds)

for batch_size in range(1, 10):
    batch = tf.stack([iterator.next() for _ in range(batch_size)], axis=0)
    print(batch)
# tf.Tensor([0], shape=(1,), dtype=int64)
# tf.Tensor([1 2], shape=(2,), dtype=int64)
# tf.Tensor([3 4 5], shape=(3,), dtype=int64)
# tf.Tensor([6 7 8 9], shape=(4,), dtype=int64)
# tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
# tf.Tensor([5 6 7 8 9 0], shape=(6,), dtype=int64)
# tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int64)
# tf.Tensor([8 9 0 1 2 3 4 5], shape=(8,), dtype=int64)
# tf.Tensor([6 7 8 9 0 1 2 3 4], shape=(9,), dtype=int64)
Upvotes: 4
Reputation: 3079
From what I know, you have to instantiate a new dataset iterator for your change to take effect. This requires a little tweaking to skip the samples you have already seen.
Here is my simplest solution:
import numpy as np
import tensorflow as tf

def get_dataset(batch_size, num_samples_seen):
    return tf.data.Dataset.range(
        100
    ).skip(
        num_samples_seen
    ).batch(
        batch_size=batch_size
    )

def main():
    batch_size = 1
    num_samples_seen = 0

    train_ds = get_dataset(batch_size, num_samples_seen)
    ds_iterator = iter(train_ds)
    while True:
        try:
            data = next(ds_iterator)
        except StopIteration:
            print("End of iteration")
            break

        print(data)

        batch_size *= 2
        num_samples_seen += data.shape[0]
        ds_iterator = iter(get_dataset(batch_size, num_samples_seen))
        print("New batch size:", batch_size)

if __name__ == "__main__":
    main()
As you can see here, you have to instantiate a new dataset (through a call to get_dataset) and update the iterator. I don't know the performance impact of such a solution. Maybe there is another solution that only re-instantiates the batch step instead of the whole dataset.
Upvotes: 1