I load a dataset like this:
import tensorflow_datasets as tfds
ds = tfds.load(
    'caltech_birds2010',
    split='train',
    as_supervised=True)  # yields (image, label) tuples, which map(pad) below unpacks
And this function works fine:
import tensorflow as tf

@tf.function
def pad(image, label):
    return (tf.image.resize_with_pad(image, 32, 32), label)

ds = ds.map(pad)
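For reference, here is a minimal self-contained sketch of why this version works. It assumes hypothetical variable-size uint8 images (produced by a made-up generator `gen`) in place of caltech_birds2010, shaped as (image, label) tuples like `as_supervised=True` gives. `tf.image.resize_with_pad` is a TensorFlow op, so it runs fine on the symbolic tensors `Dataset.map` passes in, and every element comes out 32x32.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for caltech_birds2010: two RGB images of different sizes
def gen():
    yield np.random.randint(0, 256, (60, 40, 3), dtype=np.uint8), np.int64(0)
    yield np.random.randint(0, 256, (25, 90, 3), dtype=np.uint8), np.int64(1)

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(), dtype=tf.int64),
    ),
)

def pad(image, label):
    # A pure TensorFlow op, so it can be traced into the tf.data graph
    return tf.image.resize_with_pad(image, 32, 32), label

padded = ds.map(pad)
for image, label in padded:
    print(image.shape, int(label))
```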
But when I try mapping a different built-in function:
from tensorflow.keras.preprocessing.image import random_rotation

@tf.function
def rotate(image, label):
    return (random_rotation(image, 90), label)

ds = ds.map(rotate)
I get the following error:
AttributeError: 'Tensor' object has no attribute 'ndim'
This is not the only function giving me issues, and the error happens with or without the @tf.function decorator.
Any help is greatly appreciated!
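Some context on the error: `Dataset.map` traces the mapped function into a graph regardless of the `@tf.function` decorator, so the function receives symbolic `tf.Tensor` objects rather than NumPy arrays. `random_rotation` is a NumPy/SciPy-based helper, and judging by the AttributeError it calls NumPy's `.ndim` on its input. The minimal probe below, using a hypothetical `probe` function and synthetic zero images, records what the mapped function sees at trace time.

```python
import numpy as np
import tensorflow as tf

seen = []  # what the mapped function observes while tf.data traces it

def probe(image, label):
    # This body runs at trace time with symbolic tensors, not NumPy arrays
    seen.append((tf.executing_eagerly(), isinstance(image, np.ndarray)))
    return image, label

images = np.zeros((2, 8, 8, 3), dtype=np.uint8)
labels = np.array([0, 1], dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(probe)
print(seen)
```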
I would try wrapping random_rotation in tf.py_function, which runs it eagerly on real values instead of graph tensors. For example:
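import tensorflow as tf
from tensorflow.keras.preprocessing.image import random_rotation

def _random_rotate_np(image):
    # Runs eagerly inside tf.py_function, so .numpy() gives a NumPy array.
    # random_rotation defaults to channels-first axes; these images are (H, W, C).
    rotated = random_rotation(image.numpy(), 90, row_axis=0, col_axis=1, channel_axis=2)
    return rotated.astype(image.dtype.as_numpy_dtype)

def rotate(image, label):
    im_shape = image.shape
    image = tf.py_function(_random_rotate_np, [image], image.dtype)  # only the image goes through Python
    image.set_shape(im_shape)  # py_function drops the static shape, so restore it
    return image, label

ds = ds.map(rotate)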
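The `set_shape` call is there because tf.py_function cannot infer the static shape of what the Python code returns. A small sketch on synthetic data (the `double_np` helper and zero images are hypothetical, chosen only to keep the example tiny) shows the element spec with and without it.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: two 32x32 RGB float images
images = np.zeros((2, 32, 32, 3), dtype=np.float32)
ds = tf.data.Dataset.from_tensor_slices(images)

def double_np(x):
    # Runs eagerly inside tf.py_function, so .numpy() is available
    return x.numpy() * 2

# Without set_shape, tf.data only knows the output dtype
unshaped = ds.map(lambda x: tf.py_function(double_np, [x], tf.float32))

def double_with_shape(x):
    y = tf.py_function(double_np, [x], tf.float32)
    y.set_shape(x.shape)  # reattach the static shape known before the call
    return y

shaped = ds.map(double_with_shape)
print(unshaped.element_spec.shape)  # unknown shape
print(shaped.element_spec.shape)
```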
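Both wrap a Python function for use with TensorFlow, but as discussed in What is the difference in purpose between tf.py_function and tf.function?, tf.function traces the function into a graph (so the body must be built from TensorFlow ops), whereas tf.py_function executes ordinary Python/NumPy code eagerly from inside the graph. That makes tf.py_function the more straightforward choice for a NumPy-based helper like random_rotation, even though graph code under tf.function generally performs better.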
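If the Python overhead of tf.py_function matters, one graph-native option (a swapped-in technique, not what random_rotation does internally) is the Keras preprocessing layer `tf.keras.layers.RandomRotation`, available in recent TensorFlow 2.x releases. Its `factor` is a fraction of a full turn, so `factor=0.25` draws angles between -90 and 90 degrees, similar in range to `random_rotation(image, 90)`. A sketch on hypothetical random float images:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: 4 random 32x32 RGB float images with integer labels
images = np.random.rand(4, 32, 32, 3).astype("float32")
labels = np.arange(4, dtype="int64")
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# fill_mode="nearest" mirrors random_rotation's default fill behaviour
rotate_layer = tf.keras.layers.RandomRotation(factor=0.25, fill_mode="nearest")

def rotate(image, label):
    # training=True makes the layer actually apply a random rotation
    return rotate_layer(image, training=True), label

rotated = ds.map(rotate)
for image, label in rotated.take(1):
    print(image.shape, int(label))
```

Because everything stays inside TensorFlow ops, this runs without calling back into Python for each element.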