bcsta

Reputation: 2327

Understanding Bahdanau's Attention Linear Algebra

Bahdanau's Additive Attention is recognized as the second part of equation 4 in the image below.


[image: score functions; the additive form is score(ht, hs) = vᵀ · tanh(W1·ht + W2·hs)]


I am trying to figure out the shapes of the matrices W1, W2, ht, hs and v in order to understand how this mechanism is used in this paper.

  1. Can ht and hs have different final dimensions, say (batch size, total units) and (batch size, time window)? Equation 8 in the paper mentioned above seems to be doing this.

  2. Equation 8 in the above paper uses the notation below:

    [image: W1[ht-1, ct-1]]

What does this expand to exactly?

(W1 · ht-1) + (W1 · ct-1)

or

W1 · concatenation(ht-1, ct-1)

I have seen both being used. Any quick explanation of the above matrix shapes would be much appreciated.
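For concreteness, here is a small NumPy sketch of the two interpretations (the sizes and weight names are placeholders I made up):

    import numpy as np

    batch, units, n = 32, 128, 10             # placeholder sizes
    h_prev = np.random.randn(batch, units)    # ht-1
    c_prev = np.random.randn(batch, units)    # ct-1

    # Interpretation 1: apply the same W1 to both vectors, then add
    W1_same = np.random.randn(units, n)
    out_1 = h_prev @ W1_same + c_prev @ W1_same                    # (32, 10)

    # Interpretation 2: concatenate first, then apply a single (wider) W1
    W1_wide = np.random.randn(2 * units, n)
    out_2 = np.concatenate([h_prev, c_prev], axis=-1) @ W1_wide    # (32, 10)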

Upvotes: 1

Views: 736

Answers (3)

Allohvk

Reputation: 1374

Maybe understanding this with a specific example will help: let us say you have a 19-word tweet and you want to convert it into another language. You create embeddings for the words and then pass them through a bi-directional LSTM layer of 128 units. The encoder now outputs 19 hidden states of 256 dimensions for every tweet. Let us say the decoder is uni-directional and has 128 units. It starts translating the words while outputting a hidden state at each time step in parallel.

Now you want to bring Bahdanau's attention into the above setup. You want to feed s_tminus1 of the decoder and all the hidden states of the encoder (hj), and get the context using the following steps:

generate score = v · tanh(w · s_tminus1 + u · hj)

Take a softmax of the above to get the 19 attention weights for each tweet, and then multiply these attention weights by the encoder hidden states to get the weighted sum, which is nothing but the context.

Note that in the Bahdanau model the decoder should be unidirectional. Then the shapes would be as follows:

Assume n=10 units for the alignment layer that determines w and u. Then the shapes for s_tminus1 and hj would be (?, 128) and (?, 19, 256). Note that s_tminus1 is the single decoder hidden state at t-1 and hj are the 19 hidden states of the bi-directional encoder.

We have to expand s_tminus1 to (?, 1, 128) for the addition that follows later along the time axis. The layer weights for w, u and v will be automatically determined by the framework as (128, 10), (256, 10) and (10, 1) respectively. Notice how self.w(s_tminus1) works out to (?, 1, 10). This is added to each of the self.u(hj) terms to give a shape of (?, 19, 10). The result is fed to self.v and the output is (?, 19, 1), which is the shape we want - a set of 19 weights. Softmaxing this gives the attention weights.

Multiplying these attention weights with each encoder hidden state and summing up returns the context.
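Here is a minimal sketch of those shapes, assuming TensorFlow/Keras (the names w, u, v follow the text above; the batch size of 4 is arbitrary):

    import tensorflow as tf

    # n=10 alignment units, 19 time steps, 256-dim encoder states, 128-dim decoder state
    w = tf.keras.layers.Dense(10)   # (?, 1, 128)  -> (?, 1, 10)
    u = tf.keras.layers.Dense(10)   # (?, 19, 256) -> (?, 19, 10)
    v = tf.keras.layers.Dense(1)    # (?, 19, 10)  -> (?, 19, 1)

    s_tminus1 = tf.random.normal([4, 128])       # single decoder state per tweet
    hj        = tf.random.normal([4, 19, 256])   # 19 bi-LSTM encoder states per tweet

    s_exp   = tf.expand_dims(s_tminus1, 1)                 # (?, 1, 128)
    score   = v(tf.nn.tanh(w(s_exp) + u(hj)))              # (?, 19, 1), broadcast over time
    alphas  = tf.nn.softmax(score, axis=1)                 # 19 attention weights per tweet
    context = tf.reduce_sum(alphas * hj, axis=1)           # (?, 256) weighted sum of encoder states

    print(score.shape, alphas.shape, context.shape)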

Hope this clarifies the shapes of the various tensors and weights.

To answer your other questions - the dimensions of ht and hs can be different, as shown in the above example. As for your other question, I have seen the two vectors being concatenated and then a single weight matrix applied to them... at least this is what I remember reading in the original paper.

Upvotes: 1

xdurch0

Reputation: 10474

To more directly answer your questions:

  1. ht and hs can have different shapes. The important thing is that after the matrix multiplication they are the same; otherwise they cannot be added together. That is, W1 and W2 need to map to the same output dimension.
  2. This should be taken as the concatenation of h and c. I don't think multiplying both by the same matrix and adding makes a lot of sense.

Note: Part 1) can also be implemented by concatenating ht and hs in the feature dimension and applying a single matrix multiplication. This may be more efficient than two separate ones.
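A quick NumPy check of this equivalence (all sizes are arbitrary): stacking W1 and W2 into one matrix and multiplying the concatenated vector gives the same result as the two separate multiplications.

    import numpy as np

    rng = np.random.default_rng(0)
    d_t, d_s, n = 128, 256, 10

    ht = rng.normal(size=(4, d_t))
    hs = rng.normal(size=(4, d_s))
    W1 = rng.normal(size=(d_t, n))
    W2 = rng.normal(size=(d_s, n))

    separate = ht @ W1 + hs @ W2                   # two matmuls, then add
    combined = np.concatenate([ht, hs], axis=-1) @ np.concatenate([W1, W2], axis=0)

    print(np.allclose(separate, combined))         # True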

Upvotes: 0

I_Al-thamary

Reputation: 4088

I found this more helpful: it shows the output of every equation and the shapes of the encoder and decoder states. [image: flow of calculating attention weights in Bahdanau attention]

We can see that the encoder and decoder can have different shapes, and that attention focuses on the most important parts of the sequence instead of on the entire sequence as a whole. You can also use this code, which shows how to apply these equations:

FC = fully connected (dense) layer
EO = encoder output
H = hidden state
X = input to the decoder

score = FC(tanh(FC(EO) + FC(H)))
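As a minimal sketch of those equations, assuming TensorFlow/Keras (the layer names W1, W2 and V are my own):

    import tensorflow as tf

    class BahdanauAttention(tf.keras.layers.Layer):
        """Sketch of score = FC(tanh(FC(EO) + FC(H))), followed by softmax and weighted sum."""
        def __init__(self, units):
            super().__init__()
            self.W1 = tf.keras.layers.Dense(units)   # FC applied to the encoder output (EO)
            self.W2 = tf.keras.layers.Dense(units)   # FC applied to the decoder hidden state (H)
            self.V  = tf.keras.layers.Dense(1)       # FC that collapses to one score per time step

        def call(self, query, values):
            # query: (batch, hidden), values: (batch, time, features)
            query_with_time_axis = tf.expand_dims(query, 1)                       # (batch, 1, hidden)
            score = self.V(tf.nn.tanh(self.W1(values) + self.W2(query_with_time_axis)))
            attention_weights = tf.nn.softmax(score, axis=1)                      # (batch, time, 1)
            context_vector = tf.reduce_sum(attention_weights * values, axis=1)    # (batch, features)
            return context_vector, attention_weights

    # Example usage (shapes are illustrative):
    # context, weights = BahdanauAttention(10)(decoder_state, encoder_outputs)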

Upvotes: 0
