Lau

Reputation: 1466

How to make tf.data.Dataset.from_generator yield batches with a custom generator

I want to use the tf.data API. My expected workflow looks like the following:

I use the tf.data.Dataset.from_generator function to create a dataset, from which I later make an initializable iterator.

My code would look something like this:

import numpy as np

def custom_generator():
    img = np.random.normal(size=(width, height, channels, frames)).astype(np.float32)
    yield (img, img)  # I train an autoencoder, so x == y

dataset = tf.data.Dataset.batch(batch_size).from_generator(custom_generator)
iter = dataset.make_initializable_iterator()

sess = tf.Session()
sess.run(iter.get_next())

I would expect iter.get_next() to yield a 5D tensor with a leading batch dimension. However, even when I tried to yield whole batches from my own custom_generator, it did not work. I get an error when I try to initialize the dataset with my placeholder of the input shape (batch_size, width, height, channels, frames).

Upvotes: 4

Views: 7797

Answers (1)

E_net4

Reputation: 29972

The Dataset construction process in that example is ill-formed. It should be done in this order, as also established by the official guide on Importing Data:

  1. A base dataset creation static method should be called to establish the original source of data (e.g. from_tensor_slices, from_generator, list_files, ...).
  2. At this point, transformations may be applied by chaining adapter methods (such as batch).

Thus, in TensorFlow 2.9:

dataset = tf.data.Dataset.from_generator(
    custom_generator,
    output_signature=(
        tf.TensorSpec(shape=(width, height, channels, frames), dtype=tf.float32),
        tf.TensorSpec(shape=(width, height, channels, frames), dtype=tf.float32),
    ),
).batch(batch_size)
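As a quick sanity check (a minimal sketch, assuming TF 2.x eager execution and that width, height, channels, frames, and batch_size are defined as in the question), iterating the batched dataset yields 5D tensors whose leading dimension is the batch size, or smaller for a final partial batch:

for x, y in dataset.take(1):
    # The generator above yields a single sample, so this partial batch
    # has leading dimension 1 rather than batch_size.
    print(x.shape)  # (<=batch_size, width, height, channels, frames)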

In TensorFlow 1:

dataset = tf.data.Dataset.from_generator(
    custom_generator,
    output_types=(tf.float32, tf.float32)
).batch(batch_size)
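With the dataset defined this way, the initializable iterator from the question works as intended; a minimal sketch of the TF 1.x driver code, assuming the same definitions:

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # An initializable iterator must be initialized before the first get_next().
    sess.run(iterator.initializer)
    x, y = sess.run(next_element)
    print(x.shape)  # leading dimension is the (possibly partial) batch size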

Upvotes: 5
