EthanJiang

Reputation: 43

tf.keras.layers.MultiHeadAttention's argument key_dim sometimes does not match the paper's example

For example, I have an input with shape (1, 1000, 10) (so src.shape will be (1, 1000, 10)), which means the sequence length is 1000 and the embedding dimension is 10. Then:

import tensorflow as tf

class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        # num_heads * key_dim = 180, unrelated to the embedding dimension 10
        self.attention1 = tf.keras.layers.MultiHeadAttention(num_heads=20, key_dim=9)
        self.dense = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, src):
        output = self.attention1(src, src)        # self-attention, output shape (1, 1000, 10)
        output = tf.reshape(output, [1, 10000])
        output = self.dense(output)
        return output

class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        # an even more arbitrary combination also builds and trains
        self.attention1 = tf.keras.layers.MultiHeadAttention(num_heads=123, key_dim=17)
        self.dense = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, src):
        output = self.attention1(src, src)
        output = tf.reshape(output, [1, 10000])
        output = self.dense(output)
        return output

So this layer works with any num_heads and key_dim, which does not match the paper's idea. (By "works" I mean that no error is reported and the model is able to train.)
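For reference, here is a minimal sketch (my own, not taken from the docs) that calls the layer directly with a key_dim that does not divide the embedding dimension:

import tensorflow as tf

# num_heads * key_dim = 20 * 9 = 180, which is unrelated to the embedding dimension 10
layer = tf.keras.layers.MultiHeadAttention(num_heads=20, key_dim=9)
src = tf.random.normal((1, 1000, 10))

out = layer(src, src)
print(out.shape)  # (1, 1000, 10) -- the output keeps the original embedding dimension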

In the paper 'Attention Is All You Need', key_dim is the dimension of the key for each head, not the full model dimension, and so key_dim should equal embed_dim / num_heads. So if the embedding dimension is 10 and we want num_heads = 5, key_dim has to be 2.

[screenshot from the paper]

Also, from the Keras attention class description, key_dim is the "size of each attention head for query and key", which matches the paper's idea.

[screenshot from the class description]

So why is tf.keras.layers.MultiHeadAttention able to take a mismatched dimension? And when it does, how does it work internally with these extra weight parameters?

Upvotes: 4

Views: 3134

Answers (1)

Anirban Mukherjee

Reputation: 499

There are two dimensions d_k and d_v.

  • key_dim corresponds to d_k, which can be larger or smaller than d_v. d_k is the size of the query and key projections for each head.
  • d_v = embed_dim / num_heads. d_v is the size of the value projection for each head.

In their paper, Vaswani et al. set d_k = d_v. This, however, is not required.
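To see how this plays out inside the Keras layer, you can inspect its weights after building it. The sketch below is my own; the shapes noted in the comments are what I would expect from the layer's internal per-head projections:

import tensorflow as tf

# embedding dimension 10, but 5 heads with key_dim = 7 each
layer = tf.keras.layers.MultiHeadAttention(num_heads=5, key_dim=7)
x = tf.random.normal((1, 1000, 10))
_ = layer(x, x)  # call once so the weights are created

for w in layer.weights:
    print(w.name, w.shape)

# Expected kernel shapes (biases omitted):
#   query:  (10, 5, 7)  -- project the 10-dim input to 5 heads of size key_dim = 7
#   key:    (10, 5, 7)
#   value:  (10, 5, 7)  -- value_dim defaults to key_dim when not given
#   output: (5, 7, 10)  -- project the concatenated heads back to the query's last dim

As far as I can tell, the final output projection is what lets any num_heads/key_dim combination work: whatever the per-head sizes are, the concatenated heads are mapped back to the query's last dimension, so the output shape always matches the input.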

Upvotes: 2
