Reputation: 113
Code:
import tensorflow as tf

x = tf.ones((3, 2, 4))   # only the shapes matter here, not the values
y = tf.ones((3, 21, 4))
tf.matmul(x, y)                    # Doesn't work.
tf.matmul(x, y, transpose_b=True)  # This works. Shape is (3, 2, 21).
tf.matmul(x, tf.transpose(y))      # Doesn't work.
I want to know what shape y becomes inside tf.matmul(x, y, transpose_b=True) so I can work out what is really going on inside an LSTM with attention.
Upvotes: 7
Views: 6664
Reputation: 53768
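Transpose can be defined differently for tensors of rank > 2, and the difference here is in which axes tf.transpose and tf.matmul(..., transpose_b=True) swap. With transpose_b=True, tf.matmul swaps only the last two axes of y, so inside the call y is treated as a tensor of shape (3, 4, 21).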
By default, tf.transpose does this:
The returned tensor's dimension i will correspond to the input dimension perm[i]. If perm is not given, it is set to (n-1...0), where n is the rank of the input tensor. Hence by default, this operation performs a regular matrix transpose on 2-D input Tensors.
So in your case, it's going to transform y into a tensor of shape (4, 21, 3), which is not compatible with x (see below).
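As a quick check of the default behaviour described above, here is a minimal sketch. It assumes TensorFlow 2.x with eager execution, and it uses tf.ones as placeholder data, since only the shapes matter here.

```python
import tensorflow as tf

# Placeholder tensors with the shapes from the question (values are arbitrary).
x = tf.ones((3, 2, 4))
y = tf.ones((3, 21, 4))

# With no perm argument, tf.transpose reverses all axes, i.e. perm = [2, 1, 0].
y_t = tf.transpose(y)
print(y_t.shape)  # (4, 21, 3)

# The batch dimensions (3 vs 4) and the inner dimensions (4 vs 21) no longer
# line up with x, so the product is rejected.
try:
    tf.matmul(x, y_t)
    failed = False
except (tf.errors.InvalidArgumentError, ValueError):
    failed = True
print(failed)
```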
But if you set perm=[0, 2, 1], the result is compatible:
# Works! (3, 2, 4) * (3, 4, 21) -> (3, 2, 21).
tf.matmul(x, tf.transpose(y, [0, 2, 1]))
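To confirm that this is also what transpose_b=True does, the following sketch compares the two routes on random data. It assumes TensorFlow 2.x in eager mode; the variable names (via_flag, via_perm, max_diff) are illustrative only.

```python
import tensorflow as tf

x = tf.random.normal((3, 2, 4))
y = tf.random.normal((3, 21, 4))

# Route 1: let tf.matmul transpose the last two axes of y internally.
via_flag = tf.matmul(x, y, transpose_b=True)

# Route 2: swap the last two axes explicitly, leaving the batch axis in place.
y_swapped = tf.transpose(y, perm=[0, 2, 1])
via_perm = tf.matmul(x, y_swapped)

print(y_swapped.shape)  # (3, 4, 21)
print(via_flag.shape)   # (3, 2, 21)

# Largest element-wise difference between the two results; expected to be ~0.
max_diff = float(tf.reduce_max(tf.abs(via_flag - via_perm)))
print(max_diff)
```

So the shape y effectively takes inside tf.matmul(x, y, transpose_b=True) is (3, 4, 21).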
tf.matmul
You can compute the dot product: (a, b, c) * (a, c, d) -> (a, b, d). But it's not a tensor dot product; it's a batch operation (see this question).
In this case, a is treated as the batch size, so tf.matmul computes a independent matrix products of the form (b, c) * (c, d), one for each batch index.
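The batch interpretation can be sketched by comparing tf.matmul against an explicit loop over the first axis, and against tf.tensordot for contrast. This assumes TensorFlow 2.x in eager mode; the sizes a, b, c, d and the names lhs and rhs are chosen purely for illustration.

```python
import tensorflow as tf

a, b, c, d = 3, 2, 4, 21
lhs = tf.random.normal((a, b, c))
rhs = tf.random.normal((a, c, d))

# Batched product: one (b, c) x (c, d) matrix product per batch index.
batched = tf.matmul(lhs, rhs)

# The same thing written as an explicit loop over the batch axis.
looped = tf.stack([tf.matmul(lhs[i], rhs[i]) for i in range(a)])

# A tensor dot product over the shared axis c keeps BOTH batch axes,
# so it produces a different, larger result.
contracted = tf.tensordot(lhs, rhs, axes=[[2], [1]])

print(batched.shape)     # (3, 2, 21)
print(contracted.shape)  # (3, 2, 3, 21)

# Largest element-wise difference between batched and looped; expected ~0.
loop_diff = float(tf.reduce_max(tf.abs(batched - looped)))
print(loop_diff)
```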
Batch can be more than one dimension, so this is also valid:
(a, b, c, d) * (a, b, d, e) -> (a, b, c, e)
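A small sketch of the multi-dimensional batch case, again assuming TensorFlow 2.x in eager mode and arbitrary illustrative sizes:

```python
import tensorflow as tf

# Two leading batch axes (5, 3), followed by the matrix axes.
p = tf.ones((5, 3, 2, 4))
q = tf.ones((5, 3, 4, 21))

# Only the last two axes are multiplied: (2, 4) x (4, 21) -> (2, 21).
r = tf.matmul(p, q)
print(r.shape)  # (5, 3, 2, 21)

# Each entry sums four products of ones, so every value is 4.0.
first_value = float(r[0, 0, 0, 0])
print(first_value)  # 4.0
```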
Upvotes: 7