Edward Atkins

Reputation: 466

Role of 'axes' in keras.backend.batch_dot

I have a custom layer that multiplies two tensors A and B of shapes (x, 1) and (1, y) to produce an output C of shape (x, y).

To account for batching (i.e. the tensors actually have shapes (?, x, 1) and (?, 1, y)), I am calling:

C = K.batch_dot(A, B, axes=[2, 1])

This seems to produce the desired output, but I don't really understand what the axes argument represents here. My intuition is that these are the axes over which the matrix multiplication is performed, but I don't understand why the order is [2, 1] rather than [1, 2] (which produced an error).
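A minimal NumPy sketch of what the call computes, assuming NumPy's `einsum` as a stand-in for `K.batch_dot` (this is not Keras code, and the sizes are hypothetical). Contracting axis 2 of A with axis 1 of B while keeping axis 0 as the batch is an ordinary per-sample matrix product, and because the contracted dimension has size 1, each sample's result is an outer product:

```python
import numpy as np

# Hypothetical sizes for illustration: batch of 4, x = 3, y = 5.
batch, x, y = 4, 3, 5
rng = np.random.default_rng(0)
A = rng.standard_normal((batch, x, 1))
B = rng.standard_normal((batch, 1, y))

# Contract axis 2 of A ("k") with axis 1 of B ("k"); batch axis "b" is kept.
C_einsum = np.einsum("bik,bkj->bij", A, B)

# Same result as a per-sample matrix product (matmul broadcasts over axis 0).
C_matmul = np.matmul(A, B)

# The contracted size is 1, so each C[b] is the outer product of a column and a row.
C_outer = np.stack([np.outer(A[b, :, 0], B[b, 0, :]) for b in range(batch)])

print(C_einsum.shape)  # (4, 3, 5)
```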

Can anyone help me understand this?

Upvotes: 2

Views: 1510

Answers (1)

Ayush Garg

Reputation: 161

According to the official Keras documentation for batch_dot:

The lengths of axes[0] and axes[1] should be the same

In your case A has dimensions (?, x, 1) and B has dimensions (?, 1, y).

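Axes are counted from 0, and axis 0 is the batch axis. With axes = [2, 1], axes[0] = 2 selects axis 2 of A (size 1) and axes[1] = 1 selects axis 1 of B (size 1). These two dimensions have the same size, so batch_dot sums over them and returns a tensor of shape (?, x, y), which is the desired result. With axes = [1, 2] you would instead be contracting axis 1 of A (size x) with axis 2 of B (size y), which fails unless x == y.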
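To make the shape rule concrete, here is a small pure-Python sketch. `batch_dot_shape` is a hypothetical helper, not part of Keras; it assumes the single-axis case described above (the batch axis is kept, one axis of each tensor is contracted, and the remaining axes of A are followed by the remaining axes of B):

```python
def batch_dot_shape(shape_a, shape_b, axes):
    """Hypothetical, simplified shape rule for contracting one axis of each
    batched tensor. Mimics batch_dot's behaviour for this case only; not Keras code."""
    ax_a, ax_b = axes
    if ax_a == 0 or ax_b == 0:
        raise ValueError("axis 0 is the batch axis and cannot be contracted")
    # The two contracted dimensions must have the same size.
    if shape_a[ax_a] != shape_b[ax_b]:
        raise ValueError(
            f"cannot contract axis {ax_a} of A (size {shape_a[ax_a]}) "
            f"with axis {ax_b} of B (size {shape_b[ax_b]})"
        )
    # Keep the batch axis, then A's remaining axes, then B's remaining axes.
    rest_a = [d for i, d in enumerate(shape_a) if i not in (0, ax_a)]
    rest_b = [d for i, d in enumerate(shape_b) if i not in (0, ax_b)]
    return (shape_a[0], *rest_a, *rest_b)


x, y = 3, 5  # hypothetical sizes with x != y
print(batch_dot_shape((None, x, 1), (None, 1, y), axes=[2, 1]))  # (None, 3, 5)

try:
    batch_dot_shape((None, x, 1), (None, 1, y), axes=[1, 2])
except ValueError as err:
    print(err)  # cannot contract axis 1 of A (size 3) with axis 2 of B (size 5)
```

By the same rule, if x happened to equal y, axes = [1, 2] would not raise but would contract the wrong dimensions and give a result of shape (?, 1, 1) instead of (?, x, y).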

Upvotes: 2
