In TensorFlow, how can I apply a dynamic shape to `scatter_nd`?
When I use an input tensor with a dynamic shape, I get the following error:
ValueError: Cannot convert a partially known TensorShape to a Tensor: (20, ?)
Here is the function I use. It works when `tensor` has a static shape, but with a dynamic shape (e.g. `(?, 7)`) it fails.
```python
import numpy as np
import tensorflow as tf

def tf_zero_pad_columns(tensor, columns_list, num_output_columns):
    assert tensor.shape.as_list()[1] == len(columns_list)
    assert num_output_columns >= len(columns_list)
    tensor = tf.transpose(tensor)  # (num_columns, num_rows)
    columns = tf.constant(np.array([columns_list]).T.astype('int32'))  # (num_columns, 1)
    # get_shape() is the static shape: num_rows is None (?) when the batch size is dynamic
    shape = tf.TensorShape((num_output_columns, tensor.get_shape()[1]))
    scattered = tf.scatter_nd(columns, tensor, shape=shape)
    return tf.transpose(scattered)
```
I also tried replacing `tensor.get_shape()[1]` with `-1`, but this produces a different error during training:
InvalidArgumentError: Dimension -1 must be >= 0 [[Node: lambda_40/ScatterNd ....
EDIT:
Example input with a dynamic shape (this reproduces the error):
```python
tensor = tf.placeholder(tf.float32, shape=(None, 7))
tf_zero_pad_columns(tensor, [11, 12, 13, 4, 5, 6, 7], 20)
```
Example input with a static shape:
```python
import numpy as np
tensor_np = np.tile(range(7), (4, 1)) + np.array(range(4))[:, None]
tensor = tf.constant(tensor_np)
tf_zero_pad_columns(tensor, [11, 12, 13, 4, 5, 6, 7], 20)
```
Output is:
```
array([[0, 0, 0, 0, 3, 4, 5, 6, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 5, 6, 7, 8, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 6, 7, 8, 9, 0, 0, 0, 3, 4, 5, 0, 0, 0, 0, 0, 0]])
```
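To make the intended semantics explicit, here is a plain-Python reference sketch that reproduces the output above. The name `zero_pad_columns` is hypothetical and not part of the question. It sums values that land on the same output column, assuming that is the behaviour wanted, because `tf.scatter_nd` also accumulates duplicate indices. Note that the number of rows never enters the logic, which is why the TensorFlow version should not need it when the graph is built.

```python
def zero_pad_columns(rows, columns_list, num_output_columns):
    """Reference sketch: input column j goes to output column columns_list[j];
    all other output columns are zero. Duplicate targets are summed, like tf.scatter_nd."""
    assert all(len(row) == len(columns_list) for row in rows)
    assert num_output_columns >= len(columns_list)
    out = []
    for row in rows:  # works for any number of rows
        padded = [0] * num_output_columns
        for value, target in zip(row, columns_list):
            padded[target] += value
        out.append(padded)
    return out


# Same input as the static-shape example: row r is [r, r + 1, ..., r + 6]
rows = [[c + r for c in range(7)] for r in range(4)]
result = zero_pad_columns(rows, [11, 12, 13, 4, 5, 6, 7], 20)
```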
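`tensor.get_shape()` returns the static shape, in which an unknown dimension is `None`, so `tf.TensorShape((num_output_columns, None))` cannot be converted into the `shape` tensor that `tf.scatter_nd` needs. Passing `-1` does not help either: unlike `tf.reshape`, `tf.scatter_nd` does not infer a dimension from `-1`, which is what the `Dimension -1 must be >= 0` error says. Use `tf.shape(tensor)` instead; it returns the shape as a tensor computed at run time, so the unknown dimension is filled in when the graph executes. This works for me: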
```python
import numpy as np
import tensorflow as tf

def tf_zero_pad_columns(tensor, columns_list, num_output_columns):
    assert tensor.shape.as_list()[1] == len(columns_list)
    assert num_output_columns >= len(columns_list)
    tensor = tf.transpose(tensor)  # (num_columns, num_rows)
    columns = tf.constant(np.array([columns_list]).T.astype('int32'))  # (num_columns, 1)
    # tf.shape() is evaluated at run time, so this works even when num_rows is unknown (?)
    tensor_shape = tf.shape(tensor)[1]
    scattered = tf.scatter_nd(columns, tensor, shape=(num_output_columns, tensor_shape))
    return tf.transpose(scattered)
```
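On TensorFlow 2.x, where `tf.placeholder` no longer exists, the same `tf.shape` fix can be exercised with a `tf.function` whose `input_signature` leaves the batch dimension as `None`. This is a sketch under that assumption: the wrapper name `pad_batch` and the variable `num_rows` are hypothetical, and the index list is built with a list comprehension that is equivalent to the `np.array([columns_list]).T` line above.

```python
import numpy as np
import tensorflow as tf


def tf_zero_pad_columns(tensor, columns_list, num_output_columns):
    assert tensor.shape[1] == len(columns_list)
    assert num_output_columns >= len(columns_list)
    transposed = tf.transpose(tensor)  # (num_columns, num_rows)
    indices = tf.constant([[c] for c in columns_list], dtype=tf.int32)  # (num_columns, 1)
    num_rows = tf.shape(transposed)[1]  # dynamic: a scalar int32 tensor, known at run time
    # A tuple mixing a Python int and a scalar tensor is packed into a 1-D int32 shape tensor
    scattered = tf.scatter_nd(indices, transposed, shape=(num_output_columns, num_rows))
    return tf.transpose(scattered)


# The batch dimension is None, i.e. unknown while the function is traced
@tf.function(input_signature=[tf.TensorSpec(shape=(None, 7), dtype=tf.float32)])
def pad_batch(x):
    return tf_zero_pad_columns(x, [11, 12, 13, 4, 5, 6, 7], 20)


rows = (np.tile(np.arange(7), (4, 1)) + np.arange(4)[:, None]).astype(np.float32)
out = pad_batch(tf.constant(rows)).numpy()
```

Because of the `input_signature`, the single traced graph also accepts batches of other sizes, e.g. `pad_batch(tf.ones((2, 7)))` returns a `(2, 20)` tensor.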