oldsqlwnb

Reputation: 21

iterator.get_next() returns byte array / iterator.get_next() cannot be assigned to multiple values without eager execution

I have a problem when I try to create a tf.data.Dataset from a TFRecord file via tf.data.TFRecordDataset.

def parse_function(example_proto):
    # Defaults are not specified since both keys are required.
    keys_to_features = {
        'image': tf.FixedLenFeature([1024*1024], tf.int64),
        'label': tf.FixedLenFeature([1024*1024], tf.int64)
    }
    features = tf.parse_example([example_proto], keys_to_features)
    label = features['label']
    image = features['image']
    label = tf.reshape(label, (1024, 1024))
    image = tf.reshape(image, (1024, 1024))
    return image, label

def make_batch(batch_size):
    filenames = ["train.tfrecords"]
    dataset = tf.data.TFRecordDataset(filenames).repeat()
    dataset.map(map_func=parse_function,num_parallel_calls=batch_size)
    dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    image , label  = iterator.get_next()
    return image , label

This caused the error:

Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn.

So I changed image, label = iterator.get_next() to next_elem = iterator.get_next().

With this I could execute this code:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    next_elem = sess.run(make_batch(1))

However, next_elem is an array of bytes instead of a tuple of two arrays with shapes ([1024,1024],[1024,1024]).

Upvotes: 0

Views: 668

Answers (1)

oldsqlwnb

Reputation: 21

It turned out the error was just a misunderstanding on my part.

dataset.map(map_func=parse_function,num_parallel_calls=batch_size)
dataset.batch(batch_size)

do not modify the dataset in place; each call returns a new dataset. See: Iterator.get_next() returning a tensor of shape ()

You have to assign the dataset resulting from those operations back to dataset, like so:

dataset = dataset.map(map_func=parse_function,num_parallel_calls=batch_size)
dataset = dataset.batch(batch_size)
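
To show the same pitfall in something that runs without TensorFlow, here is a minimal sketch with a toy class. ToyDataset, to_list and double are hypothetical names made up for this illustration and are not part of tf.data; the only behaviour it is meant to mirror is that map and batch hand back a new object instead of changing the one they are called on.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset; NOT a TensorFlow class."""

    def __init__(self, items):
        self._items = list(items)

    def map(self, fn):
        # Returns a NEW ToyDataset; self is left unchanged.
        return ToyDataset(fn(x) for x in self._items)

    def batch(self, size):
        # Returns a NEW ToyDataset whose items are lists of up to `size` elements.
        return ToyDataset(
            self._items[i:i + size] for i in range(0, len(self._items), size)
        )

    def to_list(self):
        return list(self._items)


def double(x):
    return x * 2


# Mistake: the return values are thrown away, so `ignored` never changes.
ignored = ToyDataset([1, 2, 3, 4])
ignored.map(double)
ignored.batch(2)
ignored_result = ignored.to_list()

# Fix: assign each returned dataset back to the variable.
kept = ToyDataset([1, 2, 3, 4])
kept = kept.map(double)
kept = kept.batch(2)
kept_result = kept.to_list()

print(ignored_result)
print(kept_result)
```

In the first half the data comes out exactly as it went in, which is the toy equivalent of getting raw bytes back; in the second half both steps take effect.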

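This also resolved the iterator.get_next() issue. Without the map step actually applied, each element is a single serialized record (one scalar string tensor), which cannot be unpacked into two values; once parse_function is applied, each element is an (image, label) pair. So I changed next_elem = iterator.get_next() back to image, label = iterator.get_next()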

and with this the following code works as expected:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    image, label = sess.run(make_batch(1))

Upvotes: 1
