user3521479


Reshape batch of tensors into batch of vectors in TensorFlow

While building a computation graph, I have a tensor x with a shape such as [-1, a, b, c], and I would like to reshape it into [-1, a*b*c]. I tried to do it this way:

import functools, operator
import tensorflow as tf
n = functools.reduce(operator.mul, x.shape[1:], 1)
tf.reshape(x, [-1, n])

but I've got an error:

TypeError: unsupported operand type(s) for *: 'int' and 'Dimension'
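The message means Python found no rule for computing int * Dimension: reduce starts from the plain integer 1 and multiplies it by each element of x.shape[1:]. A minimal pure-Python sketch reproduces the same kind of failure. FakeDimension is a hypothetical stand-in written for illustration, not TensorFlow's real Dimension class; it defines __mul__ but no __rmul__, so an int on the left-hand side cannot be multiplied by it.

```python
import functools
import operator


class FakeDimension:
    """Hypothetical stand-in for a framework dimension object (not TensorFlow's).

    It supports FakeDimension * x via __mul__, but defines no __rmul__,
    so x * FakeDimension fails when x is a plain int.
    """

    def __init__(self, value):
        self.value = value

    def __mul__(self, other):
        other_value = other.value if isinstance(other, FakeDimension) else other
        return FakeDimension(self.value * other_value)


dims = [FakeDimension(3), FakeDimension(7), FakeDimension(4)]

try:
    # The first step is 1 * FakeDimension(3): int has no rule for it and
    # FakeDimension has no __rmul__, so Python raises TypeError.
    functools.reduce(operator.mul, dims, 1)
    error_message = None
except TypeError as err:
    error_message = str(err)

print(error_message)  # unsupported operand type(s) for *: 'int' and 'FakeDimension'

# Converting to plain ints first removes the problem entirely.
plain_product = functools.reduce(operator.mul, [d.value for d in dims], 1)
print(plain_product)  # 84
```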

My question is: is there something in TensorFlow to do this operation?
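For reference, reshaping [-1, a, b, c] into [-1, a*b*c] keeps the batch dimension and lays out each sample's a*b*c values in row-major order, which is the order tf.reshape uses. The sketch below illustrates that intended result on nested Python lists; flatten_batch is a hypothetical helper for illustration only, not a TensorFlow function.

```python
def flatten_batch(batch):
    """Flatten every sample of a batch of nested lists into one flat list.

    Mirrors reshaping [-1, a, b, c] to [-1, a*b*c]: the outer (batch) level
    is kept, and each sample's values are read in row-major order.
    """

    def flatten(item):
        # Recursively walk nested lists, collecting leaves left to right.
        if isinstance(item, list):
            out = []
            for sub in item:
                out.extend(flatten(sub))
            return out
        return [item]

    return [flatten(sample) for sample in batch]


# A batch of 2 samples, each with shape [a, b, c] = [2, 2, 1]
batch = [
    [[[1], [2]], [[3], [4]]],
    [[[5], [6]], [[7], [8]]],
]
flat = flatten_batch(batch)
print(flat)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```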


Answers (1)

hbaderts


As the error message tells you, there is a problem with the types. If you create a TensorFlow placeholder, e.g. with

>>> import tensorflow as tf
>>> x = tf.placeholder(tf.float16, shape=(None, 3, 7, 4))

and access its shape attribute, the result is

>>> x.shape
TensorShape([Dimension(None), Dimension(3), Dimension(7), Dimension(4)])

and each element is an instance of

>>> type(x.shape[1])
<class 'tensorflow.python.framework.tensor_shape.Dimension'>

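i.e. TensorFlow's Dimension class. That is where the error comes from: Python has no rule for multiplying a plain int (the initial value 1 you pass to reduce) by a Dimension, so operator.mul fails. Luckily, tf.TensorShape has an as_list() method, which returns the shape as a plain Python list, with integers for known dimensions and None for unknown ones.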

>>> x.shape.as_list()
[None, 3, 7, 4]

With that, you can compute the number of elements per sample, n, just as you tried before:

>>> import functools, operator
>>> n = functools.reduce(operator.mul, x.shape.as_list()[1:], 1)
>>> n 
84
>>> y = tf.reshape(x, [-1, n])
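This approach only works when every non-batch dimension is known while the graph is built. Unknown dimensions show up as None in as_list() (like the batch dimension above), and None cannot be multiplied either. A small pure-Python sketch that operates on the list as_list() returns and fails with a clear message in that case; flattened_size is a hypothetical helper name, not part of TensorFlow.

```python
import functools
import operator


def flattened_size(shape_list):
    """Product of all dimensions after the first (batch) one.

    shape_list is a plain list such as the one returned by
    TensorShape.as_list(), e.g. [None, 3, 7, 4]. Unknown dimensions
    appear as None and cannot be multiplied, so raise a clear error.
    """
    rest = shape_list[1:]
    if any(dim is None for dim in rest):
        raise ValueError(f"non-batch dimensions must be known, got {shape_list}")
    return functools.reduce(operator.mul, rest, 1)


print(flattened_size([None, 3, 7, 4]))  # 84
```

The result can then be passed to tf.reshape(x, [-1, n]) exactly as in the session above.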

