Peyman

Reputation: 4189

Tensorflow: 'axis' argument in dot product

Can someone show me how I should use the axes argument of tf.tensordot?

I read the documentation, but it was complicated and I'm still confused. I saw another question about the axis argument of tf.one_hot, and its answers gave some good insights, but they didn't help me with tf.tensordot. I hope you can give me some insights on this too.

For example, I know I can take the dot product of a tensor and a vector like this:

import tensorflow as tf

n, m = 4, 3  # example sizes
my_vector = tf.random.uniform(shape=[n])
my_tensor = tf.random.uniform(shape=[m, n])

dp = tf.tensordot(my_tensor, my_vector, 1)  # shape (m,)

But when I batch them, adding a dimension so that they have shapes (b, n) and (b, m, n), and want to obtain a (b, m, 1) result, I don't know how to compute the dot product for every element of the batch.

Upvotes: 1

Views: 1194

Answers (1)

javidcf

Reputation: 59701

The operation that you want to do cannot be done efficiently with tf.tensordot. There is, however, a dedicated function for it, tf.linalg.matvec, which supports batch dimensions out of the box. You can also do the same thing with tf.einsum, like tf.einsum('bmn,bn->bm', my_tensors, my_vectors).
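A minimal sketch comparing the two suggestions, assuming small example sizes b, m, n and random inputs (both chosen only for illustration). Both produce one matrix-vector product per batch element, giving shape (b, m); tf.expand_dims adds the trailing dimension if the (b, m, 1) layout from the question is really needed.

```python
import numpy as np
import tensorflow as tf

# Example sizes and random inputs, chosen only for illustration
b, m, n = 2, 3, 4
my_tensors = tf.random.uniform(shape=[b, m, n])
my_vectors = tf.random.uniform(shape=[b, n])

# Batched matrix-vector product: my_tensors[i] @ my_vectors[i] for each i
res_matvec = tf.linalg.matvec(my_tensors, my_vectors)        # shape (b, m)
res_einsum = tf.einsum('bmn,bn->bm', my_tensors, my_vectors)  # shape (b, m)

# Optional: add a trailing axis if a (b, m, 1) result is required
res_column = tf.expand_dims(res_matvec, -1)                  # shape (b, m, 1)

np.testing.assert_allclose(res_matvec.numpy(), res_einsum.numpy(), rtol=1e-5)
```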

As for tf.tensordot, in general it computes an "all vs all" product of the two given tensors, while matching and reducing some axes. When no axes are matched (which you request explicitly by passing axes=[[], []]), it creates a tensor whose dimensions are those of both inputs concatenated. So, if you have my_tensors with shape (b, m, n) and my_vectors with shape (b, n) and you do:

res = tf.tensordot(my_tensors, my_vectors, axes=[[], []])

You get res with shape (b, m, n, b, n), such that res[p, q, r, s, t] == my_tensors[p, q, r] * my_vectors[s, t].
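A short check of that indexing rule, assuming small example sizes and random inputs (chosen only for illustration). Here np.einsum with 'pqr,st->pqrst' is used as an independent reference for the full outer product.

```python
import numpy as np
import tensorflow as tf

# Example sizes and random inputs, chosen only for illustration
b, m, n = 2, 3, 4
my_tensors = tf.random.uniform(shape=[b, m, n])
my_vectors = tf.random.uniform(shape=[b, n])

# No matched axes: every element of the first tensor times every element of the second
res = tf.tensordot(my_tensors, my_vectors, axes=[[], []])  # shape (b, m, n, b, n)

# Reference outer product computed with NumPy
reference = np.einsum('pqr,st->pqrst', my_tensors.numpy(), my_vectors.numpy())
np.testing.assert_allclose(res.numpy(), reference, rtol=1e-6)

# Spot-check one entry: res[p, q, r, s, t] == my_tensors[p, q, r] * my_vectors[s, t]
p, q, r, s, t = 1, 2, 3, 0, 1
np.testing.assert_allclose(res[p, q, r, s, t].numpy(),
                           (my_tensors[p, q, r] * my_vectors[s, t]).numpy(), rtol=1e-6)
```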

The axes argument is used to specify dimensions in the input tensors that are "matched". Values along matched axes are multiplied and summed (like a dot product), so those matched dimensions are reduced from the output. axes can take two different forms:

  • If it is a single integer N, then the last N dimensions of the first argument (a) are matched against the first N dimensions of the second argument (b). In your example, with N = 1, that corresponds to the dimensions with n elements in my_tensor and my_vector.
  • If it is a list, it must contain two sublists, axes_a and axes_b, each with the same number N of integers. In this form, you explicitly indicate which dimensions of the given values are matched. So, in your example, you could have passed axes=[[1], [0]], which means "match dimension 1 of the first argument (my_tensor) with dimension 0 of the second argument (my_vector)".
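Both forms can be compared on the single-matrix example from the question, assuming example sizes m and n and random inputs. The last line also shows, as an illustration, that the list form can match a dimension the integer form cannot reach: after transposing the matrix, dimension 0 is the one with n elements.

```python
import numpy as np
import tensorflow as tf

# Example sizes and random inputs, chosen only for illustration
m, n = 3, 4
my_tensor = tf.random.uniform(shape=[m, n])
my_vector = tf.random.uniform(shape=[n])

# Integer form: last 1 dim of my_tensor matched with first 1 dim of my_vector
dp_int = tf.tensordot(my_tensor, my_vector, axes=1)            # shape (m,)
# List form: explicitly match dim 1 of my_tensor with dim 0 of my_vector
dp_list = tf.tensordot(my_tensor, my_vector, axes=[[1], [0]])  # shape (m,)
# List form on the transposed (n, m) matrix: match its dim 0 with dim 0 of my_vector
dp_transposed = tf.tensordot(tf.transpose(my_tensor), my_vector, axes=[[0], [0]])

np.testing.assert_allclose(dp_int.numpy(), dp_list.numpy(), rtol=1e-6)
np.testing.assert_allclose(dp_int.numpy(), dp_transposed.numpy(), rtol=1e-6)
```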

If you now have my_tensors with shape (b, m, n) and my_vectors with shape (b, n), you would want to match dimension 2 of the first one with dimension 1 of the second one, so you could pass axes=[[2], [1]]. However, that will give you a result res with shape (b, m, b) such that res[i, :, j] is the product of matrix my_tensors[i] and vector my_vectors[j]. You could then keep only the results that you want (those where i == j) with something somewhat convoluted like tf.transpose(tf.linalg.diag_part(tf.transpose(res, [1, 0, 2]))), but you would be computing b times as many products as you need to get the same result.
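A sketch of that workaround, assuming example sizes and random inputs, checked against tf.linalg.matvec. The transpose moves the two batch axes to the end so that tf.linalg.diag_part can pick out the i == j entries.

```python
import numpy as np
import tensorflow as tf

# Example sizes and random inputs, chosen only for illustration
b, m, n = 2, 3, 4
my_tensors = tf.random.uniform(shape=[b, m, n])
my_vectors = tf.random.uniform(shape=[b, n])

# Match dim 2 of my_tensors with dim 1 of my_vectors: all batch pairs (i, j) are computed
res = tf.tensordot(my_tensors, my_vectors, axes=[[2], [1]])  # shape (b, m, b)

# (b, m, b) -> (m, b, b) -> diagonal over the last two axes (m, b) -> (b, m)
diag = tf.transpose(tf.linalg.diag_part(tf.transpose(res, [1, 0, 2])))

# Same result computed directly, without the wasted i != j products
expected = tf.linalg.matvec(my_tensors, my_vectors)          # shape (b, m)
np.testing.assert_allclose(diag.numpy(), expected.numpy(), rtol=1e-5)
```

Here res holds b * m * b dot products of length n, of which only b * m are kept, which is where the factor of b in wasted work comes from.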

Upvotes: 2
