432Hz

Reputation: 31

TensorFlow dataset .map() method not working for built-in tf.keras.preprocessing.image functions

I load in a dataset as such:

import tensorflow_datasets as tfds

ds = tfds.load(
    'caltech_birds2010',
    split='train',
    as_supervised=True)  # yields (image, label) tuples, as pad() below expects

And this function works fine:

import tensorflow as tf

@tf.function
def pad(image,label):
    return (tf.image.resize_with_pad(image,32,32),label)

ds = ds.map(pad)

But when I try to map a different built-in function:

from tensorflow.keras.preprocessing.image import random_rotation

@tf.function
def rotate(image,label):
    return (random_rotation(image,90), label)

ds = ds.map(rotate)

I get the following error:

AttributeError: 'Tensor' object has no attribute 'ndim'

This is not the only function giving me issues, and it happens with or without the @tf.function decorator.
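The observation that the decorator makes no difference is expected: tf.data.Dataset.map traces the function it is given into a graph whether or not it is wrapped in @tf.function, so inside it the image is a symbolic Tensor rather than a NumPy array. The minimal sketch below (the toy integer dataset and the function names are illustrative only) records tf.executing_eagerly() from a plain and a decorated function mapped over a dataset; both report False.

```python
import tensorflow as tf

eager_flags = []  # (function name, tf.executing_eagerly()) pairs

def plain(x):
    # Runs while Dataset.map traces it into a graph, not once per element
    eager_flags.append(("plain", tf.executing_eagerly()))
    return x

@tf.function
def decorated(x):
    # Same body, explicitly decorated; Dataset.map still traces it
    eager_flags.append(("decorated", tf.executing_eagerly()))
    return x

for fn in (plain, decorated):
    ds = tf.data.Dataset.range(3).map(fn)
    list(ds.as_numpy_iterator())  # iterate so the pipeline actually runs

print(eager_flags)
```

Because the mapped function only ever sees symbolic tensors, NumPy-based helpers such as random_rotation fail in both cases.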

Any help is greatly appreciated!

Upvotes: 3

Views: 989

Answers (1)

bugvig_pk

Reputation: 59

I would try wrapping random_rotation in tf.py_function, which runs the wrapped Python function eagerly on concrete tensors. For example:

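import tensorflow as tf
from tensorflow.keras.preprocessing.image import random_rotation

def _rotate_np(image):
    # Runs eagerly inside tf.py_function, so image.numpy() is available;
    # the axis arguments describe a (height, width, channels) image
    return random_rotation(image.numpy(), 90,
                           row_axis=0, col_axis=1, channel_axis=2)

def rotate(image, label):
    im_shape = image.shape
    image = tf.cast(image, tf.float32)  # so the output matches Tout below
    [image,] = tf.py_function(_rotate_np, [image], [tf.float32])
    image.set_shape(im_shape)  # py_function loses the static shape
    return image, label

ds = ds.map(rotate)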
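As a self-contained usage sketch that needs no TFDS download, the dataset below is built from random 32x32x3 tensors (an assumption standing in for the padded birds images), and np.rot90 with a random number of quarter turns stands in for random_rotation; np_random_rot90 is a hypothetical helper name. It shows the same pattern: NumPy code wrapped in tf.py_function, with the static shape restored afterwards.

```python
import numpy as np
import tensorflow as tf

def np_random_rot90(image):
    # Hypothetical NumPy-only stand-in for random_rotation: it needs a
    # concrete array, so it only works when called eagerly
    k = np.random.randint(4)  # number of quarter turns: 0, 1, 2 or 3
    return np.rot90(image.numpy(), k=k, axes=(0, 1)).copy()

def rotate(image, label):
    im_shape = image.shape
    [image] = tf.py_function(np_random_rot90, [image], [tf.float32])
    image.set_shape(im_shape)  # py_function output has unknown static shape
    return image, label

# Synthetic stand-in for the padded birds dataset: four 32x32 RGB images
images = tf.random.uniform([4, 32, 32, 3])
labels = tf.range(4, dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(rotate)

for image, label in ds.take(1):
    print(image.shape, label.numpy())
```

Because np.rot90 swaps height and width, the shape passed to set_shape stays valid only for square images such as the 32x32 padded ones here.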

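The question What is the difference in purpose between tf.py_function and tf.function? covers the distinction in detail. In short, tf.function (like Dataset.map, which traces its function the same way) turns Python code into a graph, so the image that reaches random_rotation is a symbolic Tensor rather than a NumPy array, hence the missing ndim attribute. tf.py_function instead runs the wrapped Python function eagerly on concrete values, which makes it the more straightforward way to call NumPy-based helpers like random_rotation, even though tf.function has the performance advantage.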
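To make the difference concrete, the sketch below (toy integer dataset; inner and outer are hypothetical names) records tf.executing_eagerly() in both places: the function passed to Dataset.map is traced into a graph, while the function wrapped by tf.py_function is executed eagerly for each element, which is why NumPy-only code can run there.

```python
import tensorflow as tf

flags = {}

def inner(x):
    # Wrapped by tf.py_function: runs as eager Python for each element
    flags["inside_py_function"] = tf.executing_eagerly()
    return x + 1

def outer(x):
    # Passed to Dataset.map: traced into a graph
    flags["inside_map"] = tf.executing_eagerly()
    [y] = tf.py_function(inner, [x], [tf.int64])
    return y

ds = tf.data.Dataset.range(3).map(outer)
result = list(ds.as_numpy_iterator())
print(result, flags)
```

The wrapped body also runs in the Python interpreter and is not serialized into the graph, which is the cost side of the performance trade-off mentioned above.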

Upvotes: 2
