bluesummers

Reputation: 12607

TensorFlow concatenate/stack N tensors interleaving last dimensions

Assume we have 4 tensors, a, b, c and d, which all have the same shape (batch_size, T, C). We want to create a new tensor X of shape (batch_size, T*4, C) in which the T*4 axis is interleaved, cycling through the four tensors one time step at a time.

For example, if a, b, c and d were tensors of all ones, twos, threes and fours, respectively, we'd expect X to look something like

[[[1,1,1...],
  [2,2,2...],
  [3,3,3...],
  [4,4,4...],
  [1,1,1...],
  [2,2,2...],
  .
  .
  .
]]
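In other words, row t*4 + k of X should be time step t of the k-th input (k = 0 for a, ..., 3 for d). A slow but explicit sketch of that index mapping, using plain NumPy loops purely as a reference (the helper name interleave_reference is hypothetical, not part of any library); a vectorized TensorFlow version can be checked against it:

```python
import numpy as np

def interleave_reference(tensors):
    """Loop-based reference: X[:, t*N + k, :] = tensors[k][:, t, :]."""
    n = len(tensors)
    batch, T, C = tensors[0].shape
    X = np.empty((batch, T * n, C), dtype=tensors[0].dtype)
    for t in range(T):
        for k, x in enumerate(tensors):
            # time step t of input k lands at row t*N + k
            X[:, t * n + k, :] = x[:, t, :]
    return X

# a, b, c, d filled with 1, 2, 3, 4 as in the question (batch=2, T=3, C=5)
a, b, c, d = (np.full((2, 3, 5), v, dtype=np.float32) for v in (1, 2, 3, 4))
X = interleave_reference([a, b, c, d])
print(X.shape)     # (2, 12, 5)
print(X[0, :, 0])  # [1. 2. 3. 4. 1. 2. 3. 4. 1. 2. 3. 4.]
```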

Upvotes: 2

Views: 1185

Answers (2)

Ohad Meir

Reputation: 714

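I think another option is to use tf.tile, provided each tensor holds the same values at every time step (as in the ones/twos/threes example): concatenate one time step from each tensor, then repeat that block along axis 1.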

import tensorflow as tf

if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()  # only needed in TF 1.x; eager is the default in TF 2.x

A = tf.ones((2, 1, 4))
B = tf.ones((2, 1, 4)) * 2
C = tf.ones((2, 1, 4)) * 3
ABC = tf.concat([A, B, C], axis=1)

print(ABC)
#tf.Tensor(
#[[[1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]]
#
# [[1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]]], shape=(2, 3, 4), dtype=float32)

X = tf.tile(ABC, multiples=[1, 3, 1])

print(X)
#tf.Tensor(
#[[[1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]
#  [1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]
#  [1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]]
#
# [[1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]
#  [1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]
#  [1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]]], shape=(2, 9, 4), dtype=float32)
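Because tf.tile repeats the whole concatenated block, it reproduces the interleaving only when every input has the same row at each time step; for general (batch_size, T, C) data the T distinct time steps are not available to it. A minimal sketch under that assumption (TensorFlow 2.x, eager mode; batch, T, C and values are illustrative), checked against a general stack-and-reshape version on the full tensors:

```python
import tensorflow as tf

# Assumption: each input is constant along T, like tf.ones(...) * k in the answer.
batch, T, C = 2, 3, 4
values = [1.0, 2.0, 3.0]

# tile route: one time step per input, concatenated, then repeated T times
steps = [tf.fill([batch, 1, C], v) for v in values]
block = tf.concat(steps, axis=1)              # shape (2, 3, 4): rows 1, 2, 3
X_tile = tf.tile(block, multiples=[1, T, 1])  # shape (2, 9, 4): rows 1, 2, 3, 1, 2, 3, ...

# general route on full (batch, T, C) tensors: stack right after the T axis, then merge
full = [tf.fill([batch, T, C], v) for v in values]
X_general = tf.reshape(tf.stack(full, axis=2), [batch, T * len(values), C])

print(bool(tf.reduce_all(tf.equal(X_tile, X_general)).numpy()))  # True
```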

Upvotes: 1

tomkot

Reputation: 956

It seems to me that your example array actually has the shape (batch_size, T, C*4) rather than (batch_size, T*4, C). Anyway, you can get what you need with tf.concat, tf.reshape, and tf.transpose. A simpler example in 2d is as follows:

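import tensorflow as tf

A = tf.ones([2, 3])
B = tf.ones([2, 3]) * 2
AB = tf.concat([A, B], axis=1)  # shape (2, 6): each row is [1, 1, 1, 2, 2, 2]
AB = tf.reshape(AB, [-1, 3])    # shape (4, 3): rows of A and B interleaved
# .numpy() needs eager mode (TF 2.x default); in TF 1.x graph mode use AB.eval() in a session
AB.numpy()  # array([[1., 1., 1.],
            #        [2., 2., 2.],
            #        [1., 1., 1.],
            #        [2., 2., 2.]], dtype=float32)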

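Concatenating A and B along axis 1 gives a matrix of shape (2, 6); reshaping it to (-1, 3) then interleaves the rows of A and B. To do this in 3D, the axis you concatenate along (the one that grows by a factor of 4) should be the last one, with the axis you want interleaved right before it. So you may need to use tf.transpose to move the axes into that order, interleave using tf.concat and tf.reshape, then tf.transpose again to restore the original order of the dimensions.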
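Extending the 2D example to the shapes in the question, here is a minimal sketch assuming TensorFlow 2.x and N inputs of identical shape (batch_size, T, C); the function names interleave_time and interleave_features are hypothetical. For the layout the question asks for, (batch_size, T*4, C), concatenating on the last axis and reshaping is enough, since T already sits right before the concatenated axis. The transpose recipe matters for the (batch_size, T, C*4) reading, where the feature columns themselves are interleaved:

```python
import tensorflow as tf

def interleave_time(tensors):
    """N x (batch, T, C) -> (batch, T*N, C): rows a0, b0, ..., a1, b1, ..."""
    n = len(tensors)
    shape = tf.shape(tensors[0])
    # Concatenate on the last axis: (batch, T, N*C), each time step holds all N rows.
    x = tf.concat(tensors, axis=-1)
    # Row-major reshape splits N*C back into N rows of length C per time step.
    return tf.reshape(x, [shape[0], shape[1] * n, shape[2]])

def interleave_features(tensors):
    """N x (batch, T, C) -> (batch, T, C*N): columns a0, b0, ..., a1, b1, ..."""
    n = len(tensors)
    shape = tf.shape(tensors[0])
    # Move C in front of T so the axis to interleave is second-to-last: (batch, C, T).
    swapped = [tf.transpose(x_k, [0, 2, 1]) for x_k in tensors]
    # Same concat + reshape trick as the 2D example: (batch, C, N*T) -> (batch, C*N, T).
    x = tf.reshape(tf.concat(swapped, axis=-1), [shape[0], shape[2] * n, shape[1]])
    # Transpose back to (batch, T, C*N).
    return tf.transpose(x, [0, 2, 1])

# Usage with tensors of all ones, twos, threes and fours (batch=2, T=3, C=5)
a, b, c, d = (tf.fill([2, 3, 5], v) for v in (1.0, 2.0, 3.0, 4.0))
X = interleave_time([a, b, c, d])      # shape (2, 12, 5)
Y = interleave_features([a, b, c, d])  # shape (2, 3, 20)
print(X[0, :, 0].numpy())   # [1. 2. 3. 4. 1. 2. 3. 4. 1. 2. 3. 4.]
print(Y[0, 0, :8].numpy())  # [1. 2. 3. 4. 1. 2. 3. 4.]
```

Equivalently, tf.reshape(tf.stack(tensors, axis=2), ...) gives the interleave_time result and tf.reshape(tf.stack(tensors, axis=-1), ...) gives the interleave_features result, without any transposes; which one reads more clearly is a matter of taste.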

Upvotes: 2
