user3933614

Reputation: 113

Why does tf.matmul(a,b, transpose_b=True) work, but not tf.matmul(a, tf.transpose(b))?

Code:

x = tf.constant([1.,2.,3.], shape = (3,2,4))
y = tf.constant([1.,2.,3.], shape = (3,21,4))
tf.matmul(x,y)                     # Doesn't work. 
tf.matmul(x,y,transpose_b = True)  # This works. Shape is (3,2,21)
tf.matmul(x,tf.transpose(y))       # Doesn't work.

I want to know what shape y becomes inside tf.matmul(x,y,transpose_b = True) so I can work out what is really going on inside an LSTM with attention.

Upvotes: 7

Views: 6664

Answers (1)

Maxim

Reputation: 53768

Transpose can be defined differently for tensors of rank > 2, and the difference here lies in which axes are transposed by tf.transpose versus by tf.matmul(..., transpose_b=True).

By default, tf.transpose does this:

The returned tensor's dimension i will correspond to the input dimension perm[i]. If perm is not given, it is set to (n-1...0), where n is the rank of the input tensor. Hence by default, this operation performs a regular matrix transpose on 2-D input Tensors.

So in your case, it's going to transform y into a tensor of shape (4, 21, 3), which is not compatible with x (see below).
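The default axis reversal is easy to check. Here is a minimal sketch using NumPy, whose np.transpose follows the same convention as tf.transpose (reverse all axes when no permutation is given):

```python
import numpy as np

y = np.zeros((3, 21, 4))

# With no perm argument, ALL axes are reversed: (3, 21, 4) -> (4, 21, 3).
reversed_all = np.transpose(y)
print(reversed_all.shape)  # (4, 21, 3)

# Swapping only the last two axes keeps the batch dimension in front.
last_two = np.transpose(y, (0, 2, 1))
print(last_two.shape)  # (3, 4, 21)
```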

But if you set perm=[0, 2, 1], the result is compatible. This is also what transpose_b=True does: it swaps only the last two axes, so inside tf.matmul your y effectively has shape (3, 4, 21):

# Works! (3, 2, 4) * (3, 4, 21) -> (3, 2, 21).
tf.matmul(x, tf.transpose(y, [0, 2, 1]))

About tf.matmul

You can compute the dot product: (a, b, c) * (a, c, d) -> (a, b, d). But this is not a tensor dot product -- it's a batch operation (see this question).

In this case, a is treated as the batch size, so tf.matmul computes a dot product of the matrices (b, c) * (c, d) for each of the a slices.

Batch can be more than one dimension, so this is also valid:

(a, b, c, d) * (a, b, d, e) -> (a, b, c, e)
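Both shapes above can be verified with NumPy's np.matmul, which follows the same batching convention as tf.matmul (all leading axes are batch dimensions; only the last two participate in the matrix product). The shapes are taken from this answer; the use of ones is just illustrative:

```python
import numpy as np

# Single batch dimension: (3, 2, 4) * (3, 4, 21) -> (3, 2, 21).
a = np.ones((3, 2, 4))
b = np.ones((3, 4, 21))
print(np.matmul(a, b).shape)  # (3, 2, 21)

# Batch spanning two leading dimensions:
# (5, 3, 2, 4) * (5, 3, 4, 6) -> (5, 3, 2, 6).
a2 = np.ones((5, 3, 2, 4))
b2 = np.ones((5, 3, 4, 6))
print(np.matmul(a2, b2).shape)  # (5, 3, 2, 6)
```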

Upvotes: 7
