Leevo

Reputation: 1753

TensorFlow 2.0: Function with @tf.function decorator doesn't take numpy functions

I am writing a function for a model in TensorFlow 2.0. It takes image_batch (a batch of RGB images as a NumPy array) and performs a specific data augmentation step that I need. The line causing problems is:

@tf.function
def augment_data(image_batch, labels):
    import numpy as np
    from tensorflow.image import flip_left_right

    image_batch = np.append(image_batch, flip_left_right(image_batch), axis=0)

    [ ... ]

NumPy's np.append() no longer works once I put the @tf.function decorator on the function. It raises:

ValueError: zero-dimensional arrays cannot be concatenated

When I call np.append() outside the function, or remove the @tf.function decorator, the code runs without problems.

Is this expected behavior? Do I have to remove the decorator to make it work, or is this a bug because TensorFlow 2.0 is still in beta? If it is a bug, how can I work around it?
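A minimal sketch of why the decorator makes the difference (the names double, double_traced and trace_log are hypothetical, chosen only for illustration): tf.function runs the Python body once to trace it into a graph, and during that trace the arguments are symbolic tensors with no concrete values, so NumPy has nothing to compute on. The exact error NumPy reports for this depends on the NumPy and TensorFlow versions installed.

```python
import tensorflow as tf

# Hypothetical helper list: records what each Python body sees when it runs
trace_log = []

def double(x):
    # Plain Python function called eagerly: x is a concrete EagerTensor
    trace_log.append(("plain", tf.executing_eagerly()))
    return x * 2

@tf.function
def double_traced(x):
    # This Python body runs while tf.function traces it into a graph;
    # at that point x is a symbolic tf.Tensor without concrete values
    trace_log.append(("traced", tf.executing_eagerly()))
    return x * 2

x = tf.constant([1.0, 2.0])
double(x)
double_traced(x)
double_traced(x)  # same input signature, so no retrace is expected
print(trace_log)
```

Because the traced body runs only while the graph is being built, any NumPy call placed there sees symbolic tensors rather than arrays, which is why the same np.append() line works without the decorator.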

Upvotes: 3

Views: 991

Answers (1)

Sharky

Reputation: 4543

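Wrap the NumPy operations in tf.py_function. Inside a @tf.function, the Python body is traced into a graph and its arguments are symbolic tensors, so NumPy cannot operate on them; tf.py_function instead runs the wrapped Python function eagerly when the graph executes, where its inputs have concrete values: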

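import numpy as np
import tensorflow as tf

def append(image_batch, flipped_batch):
    # Runs eagerly inside tf.py_function, so both inputs hold concrete values
    return np.append(image_batch.numpy(), flipped_batch.numpy(), axis=0)

@tf.function
def augment_data(image_batch):
    flipped_batch = tf.image.flip_left_right(image_batch)
    # Tout must match the dtype np.append returns; a single dtype (not a list)
    # makes tf.py_function return one tensor instead of a one-element list
    image = tf.py_function(append, inp=[image_batch, flipped_batch], Tout=image_batch.dtype)
    return image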
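An alternative not mentioned in the answer: np.append(..., axis=0) on two batches is just concatenation along the batch axis, which TensorFlow provides natively as tf.concat. The sketch below rewrites the question's augment_data using only TensorFlow ops, so it traces into the graph without a round trip through the Python interpreter and keeps static shape information. It assumes labels are per-image class labels that a horizontal flip does not change (this would not hold for, e.g., bounding boxes); the random input data is only for illustration.

```python
import numpy as np
import tensorflow as tf

@tf.function
def augment_data(image_batch, labels):
    # For a 4-D batch [batch, height, width, channels], flip_left_right
    # mirrors every image along the width axis
    flipped = tf.image.flip_left_right(image_batch)
    # tf.concat along axis 0 is the graph-compatible counterpart of np.append(..., axis=0)
    image_batch = tf.concat([image_batch, flipped], axis=0)
    # Assumption: a horizontal flip keeps each image's class label
    labels = tf.concat([labels, labels], axis=0)
    return image_batch, labels

images = np.random.rand(4, 32, 32, 3).astype(np.float32)
labels = np.arange(4, dtype=np.int32)
aug_images, aug_labels = augment_data(images, labels)
print(aug_images.shape, aug_labels.numpy())
```

NumPy inputs are converted to tensors automatically when passed to the tf.function, so the call site can stay the same as in the question.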

Upvotes: 3
