ClementWalter

Reputation: 5324

How to repeat a tf.data.Dataset

tf.data.Dataset does have a repeat method, but its output behaves much more like a tile. That is,

list(tf.data.Dataset.range(2).repeat(3).as_numpy_iterator())
# [0, 1, 0, 1, 0, 1]

is like:

np.tile(np.arange(2), 3)
# array([0, 1, 0, 1, 0, 1])

What I am looking for is the behavior of the actual NumPy repeat:

np.repeat(np.arange(2), 3)
# array([0, 0, 0, 1, 1, 1])

Upvotes: 3

Views: 873

Answers (1)

user11530462

This can be done in two steps:

  1. Repeat each element with tf.repeat inside a map function.
  2. Flatten the result with flat_map.

Code -

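# %tensorflow_version 2.x  # Colab-only magic; not needed outside Colab
import tensorflow as tf

dataset = (
    tf.data.Dataset.range(2)
    .map(lambda x: tf.repeat(x, 3))  # each scalar becomes a length-3 vector
    .flat_map(lambda y: tf.data.Dataset.from_tensor_slices(y))  # flatten back to scalars
)

list(dataset.as_numpy_iterator())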

Output -

[0, 0, 0, 1, 1, 1]
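
The same idea can be wrapped in a small reusable helper. The sketch below is one possible variant, not the only way to do it: instead of tf.repeat followed by from_tensor_slices, it builds a one-element dataset per input element with from_tensors and repeats that. The name repeat_elements is hypothetical and chosen only for illustration.

```python
import tensorflow as tf


def repeat_elements(dataset, count):
    """Repeat each element `count` times in a row (np.repeat-like, not np.tile-like).

    `repeat_elements` is an illustrative helper name, not a TensorFlow API.
    """
    # from_tensors wraps a single element in a one-element dataset,
    # repeat(count) yields that element `count` times, and flat_map
    # concatenates the per-element datasets in their original order.
    return dataset.flat_map(
        lambda x: tf.data.Dataset.from_tensors(x).repeat(count)
    )


repeated = repeat_elements(tf.data.Dataset.range(2), 3)
result = [int(v) for v in repeated.as_numpy_iterator()]
print(result)
```

Because from_tensors keeps each element whole, this variant should also work when the elements are not scalars, for example batches or tuples.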

Hope this answers your question. Happy Learning.

Upvotes: 2
