I've been trying to convert a generator I built into a tf.data.Dataset. I've come a long way, and now I have something simple like this:
```python
import tensorflow as tf

def parse_image(filename):
    file = tf.io.read_file(filename)  # this will work only with filename as tensor
    image = tf.image.decode_image(file)
    return image

def transform_img(img):
    img = parse_image(img).numpy()
    # transforms_train: my augmentation pipeline, defined elsewhere (not shown)
    img = transforms_train(image=img)["image"]
    return img
```
`transform_img` works as expected when I call it directly on a single filename, like:

```python
plt.imshow(transform_img(array_of_filenames[0]))
```
But when I map it over a dataset:

```python
dataset = tf.data.Dataset.from_tensor_slices(array_of_filenames)
dataset = dataset.map(transform_img)
```

I get the error in the title.
I am doing something silly again, aren't I? Thanks for helping!
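It is not possible to use NumPy (including `.numpy()` calls) inside the function passed to the `map` method of a `tf.data.Dataset`: `map` traces that function into a graph, where tensors are symbolic and have no concrete value to convert. Instead, you need to wrap the function in `tf.py_function` or `tf.numpy_function`. So it should look like the following: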
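To see why this happens, here is a minimal sketch (the function names `in_map` and `in_py_function` are made up for illustration) that records `tf.executing_eagerly()` in both places: inside a plain `map` function it is `False`, because the function is being traced into a graph, while inside a function wrapped by `tf.py_function` it is `True`, so `.numpy()` is available there.

```python
import tensorflow as tf

modes = {}

def in_map(x):
    # Called while tf.data traces this function into a graph,
    # so `x` is a symbolic tensor without a .numpy() value.
    modes["map"] = tf.executing_eagerly()
    return x

def in_py_function(x):
    # Called as ordinary Python for each element, with an eager tensor.
    modes["py_function"] = tf.executing_eagerly()
    return x.numpy() + 1  # .numpy() works here

ds = tf.data.Dataset.from_tensor_slices(tf.constant([1, 2, 3], dtype=tf.int64))
traced = ds.map(in_map)  # tracing happens when map() is called
wrapped = ds.map(lambda x: tf.py_function(in_py_function, [x], tf.int64))
out = [int(v) for v in wrapped]
```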
```python
dataset = dataset.map(lambda item: tf.py_function(transform_img, [item], tf.float32))
```
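Below is a self-contained sketch of the whole pipeline under a few assumptions: `transforms_train` is a hypothetical stand-in for the question's augmentation pipeline (it flips the image and scales it to `float32` in [0, 1]), the wrapper `load` is a hypothetical helper, and two small PNG files are written to a temporary directory so the example runs on its own.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def transforms_train(image):
    # Hypothetical stand-in for the real augmentation pipeline:
    # takes image=<ndarray> and returns a dict with an "image" key.
    return {"image": np.fliplr(image).astype(np.float32) / 255.0}

def parse_image(filename):
    file = tf.io.read_file(filename)
    # expand_animations=False guarantees a 3-D (H, W, C) result.
    return tf.image.decode_image(file, channels=3, expand_animations=False)

def transform_img(filename):
    # Inside tf.py_function, `filename` is an eager tensor, so .numpy() works.
    img = parse_image(filename).numpy()
    return transforms_train(image=img)["image"]

def load(filename):
    img = tf.py_function(transform_img, [filename], tf.float32)
    # py_function loses the static shape; restore what we know about it.
    img.set_shape([None, None, 3])
    return img

# Write two small PNG files so the example is self-contained.
tmp_dir = tempfile.mkdtemp()
array_of_filenames = []
for i in range(2):
    pixels = tf.constant(np.full((4, 5, 3), i * 100, dtype=np.uint8))
    path = os.path.join(tmp_dir, f"img_{i}.png")
    tf.io.write_file(path, tf.io.encode_png(pixels))
    array_of_filenames.append(path)

dataset = tf.data.Dataset.from_tensor_slices(array_of_filenames)
dataset = dataset.map(load)

results = [img.numpy() for img in dataset]
```

The `set_shape` call is optional, but without it every element has a completely unknown shape, which can trip up later steps such as `batch` or a Keras model's input checks.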
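The first argument of `tf.py_function` is the Python function you want to run, the second is the list of input tensors to pass to it, and the last (`Tout`) is the dtype of the value it returns: a single dtype for a single output, or a list of dtypes for several outputs. That dtype should match what `transform_img` actually returns (for example `tf.uint8` if your augmentation leaves the decoded image as `uint8`). The same applies to `tf.numpy_function`, except that the wrapped function receives NumPy arrays instead of eager tensors.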
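For comparison, a minimal sketch of the `tf.numpy_function` variant; `scale_np` is a hypothetical function operating on made-up integer data. Because the inputs already arrive as NumPy arrays, plain NumPy code works and no `.numpy()` call is needed.

```python
import numpy as np
import tensorflow as tf

def scale_np(x):
    # With tf.numpy_function, `x` is a NumPy array, not a tensor.
    return (x * 2).astype(np.float32)

data = np.arange(6, dtype=np.int32).reshape(3, 2)
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.map(lambda row: tf.numpy_function(scale_np, [row], tf.float32))

results = [row.numpy() for row in dataset]
```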
I don't remember reading this in the documentation; I learned it from a tutorial, which you can find here.