Reputation: 4189
Can someone show me how I should use the axes argument in tf.tensordot? I read the documentation, but it is complicated and I'm still confused. I saw another question asking about the axis argument in tf.one_hot, and the answers there had some good insights about the matter, but that didn't help me with tf.tensordot. I thought you could give me some insights on this too.
For example, I know I can compute the dot product of a matrix and a vector like this:

my_vector = tf.random.uniform(shape=[n])     # shape (n,)
my_tensor = tf.random.uniform(shape=[m, n])  # shape (m, n)
dp = tf.tensordot(my_tensor, my_vector, 1)   # shape (m,)

But when I batch them, adding one dimension so they have shapes (b, n) and (b, m, n), to obtain a result of shape (b, m, 1), I don't know how to compute the dot product for every batch.
Upvotes: 1
Views: 1194
Reputation: 59701
The operation that you want to do cannot be done (in an efficient way) with tf.tensordot. There is, however, a dedicated function for that operation, tf.linalg.matvec, which works with batches out of the box. You can also do the same thing with tf.einsum, like tf.einsum('bmn,bn->bm', my_tensors, my_vectors).
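As a quick sketch (the sizes b, m, n here are arbitrary), both options produce the same batched result of shape (b, m):

```python
import tensorflow as tf

b, m, n = 4, 3, 2  # arbitrary batch and matrix sizes
my_tensors = tf.random.uniform(shape=[b, m, n])
my_vectors = tf.random.uniform(shape=[b, n])

# Batched matrix-vector product: one (m, n) @ (n,) product per batch entry
res_matvec = tf.linalg.matvec(my_tensors, my_vectors)

# The same computation expressed with einsum: sum over n, keep b and m
res_einsum = tf.einsum('bmn,bn->bm', my_tensors, my_vectors)

print(res_matvec.shape)  # (4, 3)
print(bool(tf.reduce_all(tf.abs(res_matvec - res_einsum) < 1e-6)))  # True
```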
With respect to tf.tensordot, in general it computes an "all vs all" product of the two given tensors, matching and reducing some axes. When no axes are given (you have to explicitly pass axes=[[], []] to do this), it creates a tensor with the dimensions of both inputs concatenated. So, if you have my_tensors with shape (b, m, n) and my_vectors with shape (b, n) and you do:

res = tf.tensordot(my_tensors, my_vectors, axes=[[], []])

you get res with shape (b, m, n, b, n), such that res[p, q, r, s, t] == my_tensors[p, q, r] * my_vectors[s, t].
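You can check this "no matching" behaviour directly (a small sketch with arbitrary sizes):

```python
import tensorflow as tf

b, m, n = 4, 3, 2
my_tensors = tf.random.uniform(shape=[b, m, n])
my_vectors = tf.random.uniform(shape=[b, n])

# No axes are matched, so the output shape concatenates both input shapes
res = tf.tensordot(my_tensors, my_vectors, axes=[[], []])
print(res.shape)  # (4, 3, 2, 4, 2), i.e. (b, m, n, b, n)

# Each output element is just the product of one element from each input
p, q, r, s, t = 1, 2, 0, 3, 1
print(bool(abs(res[p, q, r, s, t]
               - my_tensors[p, q, r] * my_vectors[s, t]) < 1e-6))  # True
```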
The axes argument is used to specify dimensions in the input tensors that are "matched". Values along matched axes are multiplied and summed (like a dot product), so the matched dimensions are reduced from the output. axes can take two different forms:

- If it is a single integer N, then the last N dimensions of the first parameter are matched against the first N dimensions of the second one. In your example, that corresponds to the dimensions with n elements in my_tensor and my_vector.
- If it is a pair of lists axes_a and axes_b, each with the same number N of integers, then you are explicitly indicating which dimensions of the given values are matched. So, in your example, you could have passed axes=[[1], [0]], which means "match dimension 1 of the first parameter (my_tensor) to dimension 0 of the second parameter (my_vector)".

If you have now my_tensors
with shape (b, m, n) and my_vectors with shape (b, n), then you would want to match dimension 2 of the first one to dimension 1 of the second one, so you could pass axes=[[2], [1]]. However, that will give you a result res with shape (b, m, b) such that res[i, :, j] is the product of matrix my_tensors[i] and vector my_vectors[j]. You could then keep only the results that you want (those where i == j) with something more or less convoluted like tf.transpose(tf.linalg.diag_part(tf.transpose(res, [1, 0, 2]))), but you would be doing far more computation than you need to get the same result.
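Putting that together (a sketch; the sizes are arbitrary), you can verify that the axes=[[2], [1]] route plus the diagonal extraction matches tf.linalg.matvec, while computing b times as many products as needed:

```python
import tensorflow as tf

b, m, n = 4, 3, 2
my_tensors = tf.random.uniform(shape=[b, m, n])
my_vectors = tf.random.uniform(shape=[b, n])

# Match dim 2 of the tensors to dim 1 of the vectors: every batch entry of
# the first input is multiplied with every batch entry of the second one
res = tf.tensordot(my_tensors, my_vectors, axes=[[2], [1]])
print(res.shape)  # (4, 3, 4), i.e. (b, m, b)

# Keep only the i == j entries to recover the batched matrix-vector product
diag = tf.transpose(tf.linalg.diag_part(tf.transpose(res, [1, 0, 2])))
expected = tf.linalg.matvec(my_tensors, my_vectors)
print(bool(tf.reduce_all(tf.abs(diag - expected) < 1e-6)))  # True
```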
Upvotes: 2