Reputation: 1753
I am writing a function to implement a model in TensorFlow 2.0. It takes image_batch (a batch of image data in numpy RGB format) and performs a specific data augmentation task that I need. The line that is causing me problems is:
@tf.function
def augment_data(image_batch, labels):
    import numpy as np
    from tensorflow.image import flip_left_right
    image_batch = np.append(image_batch, flip_left_right(image_batch), axis=0)
    [ ... ]
numpy's .append() function no longer works once I put the @tf.function decorator on top of it. It returns:
ValueError: zero-dimensional arrays cannot be concatenated
When I use the np.append() command outside of the function, or without the @tf.function decorator on top, the code runs without problems.
Is this normal? Am I forced to remove the decorator to make it work? Or is this a bug, due to the fact that TensorFlow 2.0 is still a beta version? In that case, how can I solve this?
Upvotes: 3
Views: 991
Reputation: 4543
Just wrap the numpy ops in tf.py_function:
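import numpy as np
import tensorflow as tf

def append(image_batch, tf_func):
    # Runs eagerly inside tf.py_function, so numpy sees concrete values.
    return np.append(image_batch, tf_func, axis=0)

@tf.function
def augment_data(image_batch):
    # Tout must match the dtype that append() returns (float32 assumed here).
    image = tf.py_function(append, inp=[image_batch, tf.image.flip_left_right(image_batch)], Tout=[tf.float32])
    return image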
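If you would rather stay inside the graph, a TensorFlow-only version should also work. Inside @tf.function the arguments are symbolic tensors that numpy cannot read values from, which is most likely where the error comes from. Here is a minimal sketch using tf.concat in place of np.append. The batch shape and the float32 dtype are my assumptions, not something from your code:

```python
import numpy as np
import tensorflow as tf

@tf.function
def augment_data(image_batch):
    # Flip every image horizontally, then stack the flipped copies
    # after the originals along the batch axis.
    flipped = tf.image.flip_left_right(image_batch)
    return tf.concat([image_batch, flipped], axis=0)

# Hypothetical input: 4 RGB images of 32x32 pixels, float32 (assumed).
batch = np.random.rand(4, 32, 32, 3).astype(np.float32)
augmented = augment_data(batch)
print(augmented.shape)  # (8, 32, 32, 3)
```

If you also pass labels, they would need to be doubled the same way, for example with tf.concat([labels, labels], axis=0).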
Upvotes: 3