Reputation: 1466
I want to use the tf.data
API. My expected workflow looks like the following:
The input image is a 5D tensor with shape (batch_size, width, height, channels, frames).
The first layer is a 3D convolution.
I use the tf.data.Dataset.from_generator
function to create a dataset. Later I make an initializable iterator from it.
My code would look something like this:
import numpy as np
import tensorflow as tf

def custom_generator():
    img = np.random.normal(size=(width, height, channels, frames))
    yield (img, img)  # I train an autoencoder, so x == y

dataset = tf.data.Dataset.batch(batch_size).from_generator(custom_generator)
iter = dataset.make_initializable_iterator()
sess = tf.Session()
sess.run(iter.get_next())
I would expect iter.get_next()
to yield a 5D tensor that includes the batch dimension. It does not work even when I yield already-batched arrays from my own custom_generator
: I face an error when I try to initialize the dataset with a placeholder of shape (batch_size, width, height, channels, frames).
Upvotes: 4
Views: 7797
Reputation: 29972
The Dataset
construction process in that example is ill-formed. As the official guide on Importing Data also establishes, it should go in this order: first create the Dataset from a source (from_tensor_slices
, from_generator
, list_files
, ...), then chain transformations such as .batch()
onto it. Thus, in TensorFlow 2.9:
dataset = tf.data.Dataset.from_generator(
    custom_generator,
    output_signature=(
        tf.TensorSpec(shape=(width, height, channels, frames), dtype=tf.float32),
        tf.TensorSpec(shape=(width, height, channels, frames), dtype=tf.float32),
    ),
).batch(batch_size)
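For completeness, a self-contained smoke test might look like the following. This is only a sketch: the dimension values are invented for illustration, and the .astype(np.float32) cast makes the yielded arrays match the declared TensorSpec dtype:
import numpy as np
import tensorflow as tf

# Hypothetical dimensions, not from the original post:
width, height, channels, frames, batch_size = 64, 64, 3, 10, 4

def custom_generator():
    for _ in range(100):  # yield one unbatched example per iteration
        img = np.random.normal(size=(width, height, channels, frames)).astype(np.float32)
        yield (img, img)  # autoencoder: x == y

dataset = tf.data.Dataset.from_generator(
    custom_generator,
    output_signature=(
        tf.TensorSpec(shape=(width, height, channels, frames), dtype=tf.float32),
        tf.TensorSpec(shape=(width, height, channels, frames), dtype=tf.float32),
    ),
).batch(batch_size)

for x, y in dataset.take(1):
    print(x.shape)  # (4, 64, 64, 3, 10): batch_size is prepended by .batch()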
In TensorFlow 1, from_generator additionally requires at least the output_types argument:
dataset = tf.data.Dataset.from_generator(
    custom_generator, output_types=(tf.float32, tf.float32)
).batch(batch_size)
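From there, the initializable-iterator pattern from the question works once the iterator's initializer is run before fetching elements; a minimal sketch, assuming the TF1 dataset above:
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)  # must run before get_next() is evaluated
    x, y = sess.run(next_element)
    # x.shape == (batch_size, width, height, channels, frames),
    # or a smaller leading dimension for a final partial batch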
Upvotes: 5