Reputation: 307
I have two tensors with the shapes shown below:
batch.shape = [?, 5, 4]
weight.shape = [3, 5]
By multiplying weight with every [5, 4] element of the batch, I want to get
result.shape = [?, 3, 4]
What is the most efficient way to achieve this?
Upvotes: 0
Views: 1540
Reputation: 531
Try this:
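result = tf.einsum("ijk,aj->iak", batch, weight)
tf.einsum performs a generalized contraction between tensors of arbitrary dimension; here it sums over the shared index j (the axis of size 5). Refer to the tf.einsum documentation for more information.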
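To see what that subscript string computes, here is a minimal sketch. It uses NumPy rather than TensorFlow purely so it runs without a TF install (np.einsum accepts the same subscript string); the batch size of 2 and the random data are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 5, 4))   # stands in for [?, 5, 4] with ? = 2
weight = rng.standard_normal((3, 5))

# Contract over j, the shared axis of size 5:
# result[i, a, k] = sum_j batch[i, j, k] * weight[a, j]
result = np.einsum("ijk,aj->iak", batch, weight)

# Reference: multiply weight with each [5, 4] element of the batch separately
expected = np.stack([weight @ batch[i] for i in range(batch.shape[0])])

print(result.shape)  # (2, 3, 4)
print(np.allclose(result, expected))
```

The comparison against the per-element loop is just a sanity check that the contraction matches "weight times every element in the batch".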
Upvotes: 0
Reputation: 2594
Try this:
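newbatch = tf.transpose(batch, [1, 0, 2])   # [5, ?, 4]: move the contracted axis to the front
newbatch = tf.reshape(newbatch, [5, -1])    # [5, ?*4]: flatten batch and column axes together
result = tf.matmul(weight, newbatch)        # [3, 5] x [5, ?*4] -> [3, ?*4]
result = tf.reshape(result, [3, -1, 4])     # [3, ?, 4]: split the flattened axis again
result = tf.transpose(result, [1, 0, 2])    # [?, 3, 4]: put the batch axis back in front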
Or more compactly:
newbatch = tf.reshape(tf.transpose(batch,[1,0,2]),[5,-1])
result = tf.transpose(tf.reshape(tf.matmul(weight,newbatch),[3,-1,4]), [1,0,2])
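As a rough check that the transpose/reshape/matmul route gives the same values as the einsum answer, the steps can be sketched in NumPy (again a stand-in for TensorFlow; both reshape in row-major order). The helper name apply_weight and the batch size of 7 are hypothetical, chosen only for this illustration.

```python
import numpy as np

def apply_weight(batch, weight):
    """Hypothetical helper mirroring the transpose/reshape/matmul steps above."""
    n, inner, cols = batch.shape          # e.g. (n, 5, 4)
    out = weight.shape[0]                 # e.g. 3
    # [n, 5, 4] -> [5, n, 4] -> [5, n*4]
    flat = np.transpose(batch, (1, 0, 2)).reshape(inner, -1)
    # [3, 5] x [5, n*4] -> [3, n*4]
    prod = weight @ flat
    # [3, n*4] -> [3, n, 4] -> [n, 3, 4]
    return np.transpose(prod.reshape(out, -1, cols), (1, 0, 2))

rng = np.random.default_rng(1)
batch = rng.standard_normal((7, 5, 4))
weight = rng.standard_normal((3, 5))

result = apply_weight(batch, weight)
reference = np.einsum("ijk,aj->iak", batch, weight)

print(result.shape)  # (7, 3, 4)
print(np.allclose(result, reference))
```

Which of the two is faster in TensorFlow likely depends on the version and the shapes involved, so it is worth timing both on the real data.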
Upvotes: 1