cool customer

Reputation: 21

Unable to decode batch of images with tf.io.decode_image

I am trying to run facial keypoint regression. I successfully created a TFRecord file with the images and labels encoded in it (the labels are facial keypoints).

Then I started loading the data (images and keypoints) into memory, following this guide: https://gist.github.com/FirefoxMetzger/c143c340c71e85c0c23c7ced94a88c16#file-faster_fully_connected_reader-py. I wanted to batch all the images first and then decode them, as described in that guide. However, this does not work. If my understanding is correct, tf.io.decode_image() can only be used on a single image, not on a batch. Is that right? If so, how can I decode a batch of images?

Thank you in advance!

CC

Here is the code:

ds = tf.data.TFRecordDataset(TFR_FILENAME)

ds = ds.repeat(EPOCHS)

ds = ds.shuffle(BUFFER_SIZE + BATCH_SIZE)

ds = ds.batch(BATCH_SIZE)

Finally, I tried to decode the images using tf.io.decode_image():

feature_description = {'height': tf.io.FixedLenFeature([], tf.int64),
                    'width': tf.io.FixedLenFeature([], tf.int64),
                    'depth': tf.io.FixedLenFeature([], tf.int64),
                    'kpts': tf.io.FixedLenFeature([136], tf.float32),
                    'image_raw': tf.io.FixedLenFeature([], tf.string),
                    }

for record in ds.take(1):
    record = tf.io.parse_example(record, feature_description)
    decoded_image = tf.io.decode_image(record['image_raw'], dtype=tf.float32)

This throws the following ValueError:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-3583be9c40ab> in <module>()
      1 for record in ds.take(1):
      2     record = tf.io.parse_example(record, feature_description)
----> 3     decoded_image = tf.io.decode_image(record['image_raw'], dtype=tf.float32)

3 frames
/tensorflow-2.0.0/python3.6/tensorflow_core/python/ops/image_ops_impl.py in decode_image(contents, channels, dtype, name, expand_animations)
   2315     # as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1).
   2316     return control_flow_ops.cond(
-> 2317         is_jpeg(contents), _jpeg, check_png, name='cond_jpeg')
   2318 
   2319 

/tensorflow-2.0.0/python3.6/tensorflow_core/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

/tensorflow-2.0.0/python3.6/tensorflow_core/python/ops/control_flow_ops.py in cond(pred, true_fn, false_fn, strict, name, fn1, fn2)
   1199   with ops.name_scope(name, "cond", [pred]):
   1200     if context.executing_eagerly():
-> 1201       if pred:
   1202         result = true_fn()
   1203       else:

/tensorflow-2.0.0/python3.6/tensorflow_core/python/framework/ops.py in __bool__(self)
    874 
    875   def __bool__(self):
--> 876     return bool(self._numpy())
    877 
    878   __nonzero__ = __bool__

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()


Answers (1)

Dan Moldovan

Reputation: 975

Indeed, decode_image only works with single images. You should still get reasonable performance by doing the decoding in the dataset, before batching.

Something like this (code not tested, might need some tweaks):

ds = tf.data.TFRecordDataset(TFR_FILENAME)

def parse_and_decode(record):
  # map() runs before batch(), so each element is a single serialized example
  record = tf.io.parse_single_example(record, feature_description)
  record['image'] = tf.io.decode_image(record['image_raw'], dtype=tf.float32)
  return record

ds = ds.map(parse_and_decode)

ds = ds.repeat(EPOCHS)

ds = ds.shuffle(BUFFER_SIZE + BATCH_SIZE)
...
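
For reference, here is a slightly fuller sketch of the same idea (decode each image inside map(), then batch) that can be run end to end. It is only an illustration under some assumptions: the dummy TFRecord writer, the file name, the constant values, and IMG_SHAPE are all made up for the example, and it assumes every stored image has the same size. decode_image does not report a fully known static shape, so the sketch pins the shape with tf.ensure_shape before batching; if your images differ in size you would need to resize them instead.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical values for illustration; the question does not state them.
BATCH_SIZE = 4
BUFFER_SIZE = 16
IMG_SHAPE = (32, 32, 3)  # assumption: all stored images share this shape

feature_description = {
    'kpts': tf.io.FixedLenFeature([136], tf.float32),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}


def write_dummy_tfrecord(path, n=8):
    """Write n random PNG images with random keypoints (stand-in data)."""
    with tf.io.TFRecordWriter(path) as writer:
        for _ in range(n):
            img = np.random.randint(0, 256, IMG_SHAPE, dtype=np.uint8)
            png = tf.io.encode_png(img).numpy()
            kpts = np.random.rand(136).astype(np.float32)
            feature = {
                'image_raw': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[png])),
                'kpts': tf.train.Feature(
                    float_list=tf.train.FloatList(value=kpts)),
            }
            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


def parse_and_decode(serialized):
    """Parse one serialized example and decode its image."""
    record = tf.io.parse_single_example(serialized, feature_description)
    # expand_animations=False keeps the result 3-D even for GIF input.
    image = tf.io.decode_image(
        record['image_raw'], dtype=tf.float32, expand_animations=False)
    # Pin the static shape so the batched tensor has a defined shape.
    image = tf.ensure_shape(image, IMG_SHAPE)
    return image, record['kpts']


path = os.path.join(tempfile.mkdtemp(), 'faces.tfrecord')
write_dummy_tfrecord(path)

ds = tf.data.TFRecordDataset(path)
ds = ds.map(parse_and_decode, num_parallel_calls=2)  # decode per image
ds = ds.shuffle(BUFFER_SIZE)
ds = ds.batch(BATCH_SIZE)  # batch only after decoding

images, kpts = next(iter(ds))
print(images.shape, kpts.shape)
```

Passing num_parallel_calls to map() lets several images decode concurrently, which should recover much of the speed that batching before decoding was meant to give.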
