Jian Liu

Reputation: 41

In tf.function, how do I convert a dynamic tensor shape obtained by tf.shape() to Python values rather than tensors?

I have a padded batch from tf.data, and because each padded batch's shape is not fixed, I have to use tf.shape to get the dynamic shape of the padded batch. The question is: how can I convert the tensor shape obtained by tf.shape to Python values inside a tf.function?

@tf.function
def train_step(padded_batch):
    shape = tf.shape(padded_batch)
    x = np.zeros(shape[0], shape[1])

As in the code above, I want to create a NumPy array with the same shape as padded_batch, but shape is a tensor and can't be used directly in NumPy. Is there some way to convert a tensor to Python values inside a tf.function?

The TensorFlow version I use is 2.0.

Upvotes: 4

Views: 1513

Answers (2)

EyesBear

Reputation: 1466

As described in the TF documentation:

Within @tf.function, or within a compat.v1 context, not all dimensions may be known until execution time. Hence, when defining custom layers and models for graph mode, prefer the dynamic tf.shape(x) over the static x.shape.

Your code was almost fine; I just replaced np.zeros with tf.zeros. The @tf.function decorator means the code runs in graph mode, where NumPy operations are not allowed. Tested in TF 2.x.

@tf.function
def train_step(padded_batch):
    shape = tf.shape(padded_batch)
    return tf.zeros((shape[0], shape[1]))
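For example, calling the function with differently shaped batches (a quick check of the above, assuming a TF 2.x install with eager execution enabled by default):

```python
import tensorflow as tf

@tf.function
def train_step(padded_batch):
    # tf.shape produces a tensor whose values are filled in at execution
    # time, so the same function handles any padded batch shape.
    shape = tf.shape(padded_batch)
    return tf.zeros((shape[0], shape[1]))

a = train_step(tf.ones((3, 5)))
b = train_step(tf.ones((2, 7)))
print(a.shape)  # (3, 5)
print(b.shape)  # (2, 7)
```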

Upvotes: 0

ShlomiF

Reputation: 2895

Assuming you have a tensor named a_tensor:

this_is_a_regular_non_tensor_shape = a_tensor.shape.as_list()

(BTW: you don't seem to be using np.zeros correctly. You need to pass the shape as a single tuple/list argument, not separate arguments for each dimension. For instance:

shape = padded_batch.shape.as_list()
x = np.zeros(shape)

Hope that helps.)
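One caveat worth noting: .shape.as_list() returns plain Python ints only when the static shape is fully known (e.g. in eager mode); inside a @tf.function traced with dynamically shaped inputs, unknown dimensions come back as None. A minimal eager-mode sketch of this answer:

```python
import numpy as np
import tensorflow as tf

a_tensor = tf.ones((3, 5))
# In eager mode the static shape is fully known, so as_list()
# yields ordinary Python ints that NumPy can consume directly.
shape = a_tensor.shape.as_list()  # [3, 5]
x = np.zeros(shape)
print(x.shape)  # (3, 5)
```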

Upvotes: 2
