John

Reputation: 33

In a decorated tf.function, I get the error: 'Tensor' object has no attribute 'numpy'

I've looked all over but can't find an existing answer that helps.

I have a TensorFlow model (TF version 2.3.0) whose training step is decorated with @tf.function. Within the train_step call, I need to pass the data from a tensor on to a NumPy function that performs a CWT (continuous wavelet transform) on it. There is (afaik) no TensorFlow CWT, hence the need to hand this off to a NumPy function. The issue is that within the @tf.function, the tensors are symbolic graph tensors, so one cannot directly call .numpy() to turn a tensor into a NumPy array. The small code snippet below shows the problem.

My question is: how can I transform the generated output data from my generator call into something I can pass on to this NumPy function? Here's hoping there's a way to do this!

Thanks.

@tf.function
def train_step(self, true_data):
    noise = tf.random.uniform(shape=[1, 100, 1], minval=0, maxval=1, dtype=tf.float32) 
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_data = self.generator(noise)
        nump_data = generated_data.numpy()
        # ^ this line produces: AttributeError: 'Tensor' object has no attribute 'numpy'

Upvotes: 2

Views: 796

Answers (1)

user11530462

Reputation:

As you mention, per the tf.function rules you cannot call .numpy() inside a tf.function, because the decorated code runs in graph mode.
There is still a workaround to convert the Tensor to a NumPy array while graph mode is enabled, using eval() with a TF1 compatibility session.

Below is the modified code, which should help in your case.

@tf.function
def train_step(self, true_data):
    noise = tf.random.uniform(shape=[1, 100, 1], minval=0, maxval=1, dtype=tf.float32) 
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_data = self.generator(noise)
        nump_data = generated_data.eval(session=tf.compat.v1.Session())  

If the tf.function decorator is not strictly necessary for your code, you can drop it and run the training step eagerly instead of in graph mode; TensorFlow 2.3 enables eager execution by default, so .numpy() works directly and these issues go away.
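As a minimal sketch of that eager variant (the generator below is a hypothetical stand-in for your model, and cwt_transform is only a placeholder for whatever CWT implementation you use, e.g. scipy.signal.cwt or pywt.cwt):

import numpy as np
import tensorflow as tf

# Stand-in generator, shapes chosen only so the example runs end to end.
generator = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(100, 1)),
    tf.keras.layers.Dense(100),
    tf.keras.layers.Reshape((100, 1)),
])

def cwt_transform(signal):
    # Placeholder for the real NumPy-based CWT; returns the array unchanged.
    return signal

def train_step(true_data):  # note: no @tf.function decorator
    noise = tf.random.uniform(shape=[1, 100, 1], minval=0, maxval=1, dtype=tf.float32)
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_data = generator(noise)
        # Under eager execution (the TF2 default) this is an EagerTensor,
        # so .numpy() works and the result can go straight to NumPy code.
        nump_data = generated_data.numpy()
        cwt_data = cwt_transform(nump_data)
    return cwt_data

print(train_step(np.zeros((1, 100, 1), dtype=np.float32)).shape)

The trade-off is that without @tf.function you lose the graph-compilation speedup it provides, but the NumPy hand-off then happens in ordinary Python with no conversion tricks.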

Upvotes: 1
