user3034281

Reputation: 307

TensorFlow: multiply a 3D batch tensor with a 2D weight

I have two tensors with the following shapes:

batch.shape = [?, 5, 4]
weight.shape = [3, 5]

By multiplying the weight with every element of the batch (that is, the [3, 5] by [5, 4] matrix product weight × batch[i] for each i), I want to get

result.shape = [?, 3, 4]

What is the most efficient way to achieve this?
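To pin down the intended semantics, here is a naive pure-Python reference (a sketch for illustration only, not a TensorFlow solution): result[b][a][k] is the sum over j of weight[a][j] * batch[b][j][k]. The function name reference_multiply and the all-ones sample data are hypothetical.

```python
def reference_multiply(batch, weight):
    """Naive reference for the requested product.

    batch:  nested lists of shape [B, J, K] (J = 5, K = 4 in the question)
    weight: nested lists of shape [A, J]    (A = 3 in the question)
    Returns nested lists of shape [B, A, K] where
    result[b][a][k] = sum over j of weight[a][j] * batch[b][j][k],
    i.e. the matrix product weight @ batch[b] for every batch element b.
    """
    num_rows = len(weight)      # A
    inner = len(weight[0])      # J, must equal the row count of each batch element
    cols = len(batch[0][0])     # K
    result = []
    for element in batch:
        if len(element) != inner:
            raise ValueError("inner dimensions do not match")
        result.append([
            [sum(weight[a][j] * element[j][k] for j in range(inner)) for k in range(cols)]
            for a in range(num_rows)
        ])
    return result


# Usage with a batch of 2 and all-ones inputs (hypothetical data):
batch = [[[1.0] * 4 for _ in range(5)] for _ in range(2)]
weight = [[1.0] * 5 for _ in range(3)]
out = reference_multiply(batch, weight)
print(len(out), len(out[0]), len(out[0][0]))  # 2 3 4
print(out[0][0][0])  # 5.0
```

Any efficient TensorFlow version should agree with this loop; the loop itself is far too slow for real batches.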

Upvotes: 0

Views: 1540

Answers (2)

user1531248

Reputation: 531

Try this:

import tensorflow as tf
result = tf.einsum("ijk,aj->iak", batch, weight)  # [?, 5, 4] and [3, 5] -> [?, 3, 4], summing over j

tf.einsum performs a generalized contraction between tensors of arbitrary dimension. Refer to the tf.einsum documentation for more information.
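The subscripts can be read as follows: i indexes the batch, j is the shared dimension of size 5 that is summed away, k is the trailing dimension of size 4, and a is the output dimension of size 3. As a hedged sketch that runs without TensorFlow, the snippet below uses NumPy's np.einsum, which accepts the same equation string, on random data (the batch size of 2 is an arbitrary assumption) and compares it with an explicit weight @ batch[i] for each example.

```python
import numpy as np

# NumPy stand-ins for the TensorFlow tensors (assumed batch size 2, random data).
rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 5, 4))   # shape [?, 5, 4] with ? = 2
weight = rng.standard_normal((3, 5))     # shape [3, 5]

# Same equation as the tf.einsum call: contract over j, keep i, a, k.
result = np.einsum("ijk,aj->iak", batch, weight)

# Cross-check against an explicit per-example matrix product.
expected = np.stack([weight @ batch[i] for i in range(batch.shape[0])])

print(result.shape)                   # (2, 3, 4)
print(np.allclose(result, expected))  # True
```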

Upvotes: 0

asakryukin
asakryukin

Reputation: 2594

Try this:

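import tensorflow as tf

newbatch = tf.transpose(batch, [1, 0, 2])  # [?, 5, 4] -> [5, ?, 4]
newbatch = tf.reshape(newbatch, [5, -1])   # [5, ?, 4] -> [5, ? * 4]
result = tf.matmul(weight, newbatch)       # [3, 5] x [5, ? * 4] -> [3, ? * 4]
result = tf.reshape(result, [3, -1, 4])    # [3, ? * 4] -> [3, ?, 4]
result = tf.transpose(result, [1, 0, 2])   # [3, ?, 4] -> [?, 3, 4]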

Or more compactly:

newbatch = tf.reshape(tf.transpose(batch, [1, 0, 2]), [5, -1])
result = tf.transpose(tf.reshape(tf.matmul(weight, newbatch), [3, -1, 4]), [1, 0, 2])
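To check the reshape bookkeeping without a TensorFlow session, here is a NumPy transcription of the same five steps. It assumes np.transpose, ndarray.reshape and @ stand in for tf.transpose, tf.reshape and tf.matmul. Both libraries reshape in row-major order, which is what makes the flatten-then-unflatten trick line up. The helper name transpose_matmul and the batch size of 6 are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
batch = rng.standard_normal((6, 5, 4))   # [?, 5, 4] with ? = 6 (assumed)
weight = rng.standard_normal((3, 5))     # [3, 5]


def transpose_matmul(batch, weight):
    """Hypothetical NumPy mirror of the transpose/reshape/matmul answer."""
    t = np.transpose(batch, (1, 0, 2))   # [?, 5, 4] -> [5, ?, 4]: contracted axis first
    flat = t.reshape(5, -1)              # [5, ?, 4] -> [5, ? * 4]: one 2-D matrix
    prod = weight @ flat                 # [3, 5] @ [5, ? * 4] -> [3, ? * 4]
    # Undo the flattening, then move the batch axis back to the front.
    return np.transpose(prod.reshape(3, -1, 4), (1, 0, 2))  # -> [?, 3, 4]


result = transpose_matmul(batch, weight)
print(result.shape)                                                 # (6, 3, 4)
print(np.allclose(result, np.einsum("ijk,aj->iak", batch, weight)))  # True
```

Moving the size-5 axis to the front lets a single 2-D matmul cover every batch element at once. Like the original answer, the sketch hard-codes the sizes 5, 3 and 4.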

Upvotes: 1
