Reputation: 121
I am struggling to mask my input for the MultiHeadAttention layer. I am using the Transformer block from the Keras documentation with self-attention. I could not find any example code online so far and would appreciate it if someone could give me a code snippet.
The transformer block from this page:
from tensorflow import keras
from tensorflow.keras import layers

class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(TransformerBlock, self).__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
The documentation for masking can be found under this link:
attention_mask: a boolean mask of shape [B, T, S], that prevents attention to certain positions. The boolean mask specifies which query elements can attend to which key elements, 1 indicates attention and 0 indicates no attention. Broadcasting can happen for the missing batch dimensions and the head dimension.
The only thing I could get running is a mask created outside of the layer class as a NumPy array:
mask = np.ones((observations, sequence_length, sequence_length))
mask[X[:observations,:,0]==0]=0
This mask is then passed in when calling the layer, with the only change in the transformer block being:
def call(self, inputs, mask, training):
    attn_output = self.att(inputs, inputs, attention_mask=mask)
However, this of course does not work when a batch_size is given during fitting, and with my memory it only works for 5 observations, so it doesn't make much sense. Apart from that, I don't think this masks the input properly. In general I am quite confused about how to mask, given that the attention_mask has shape (observations, sequence_length, sequence_length) while the shape of my input is (observations, sequence_length, features). This input is padded with zeros; however, by the time it reaches the transformer block, it has already been through an embedding layer and a CNN. I have tried various ways to write a function that creates the mask during training, using different Tensor or Keras objects, but I run into errors each time.
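To make the shapes concrete, this is roughly the kind of mask construction I am after, sketched here with hypothetical names and built from the raw zero-padded input (i.e. before the embedding layer and CNN):
import tensorflow as tf

def make_attention_mask(raw_inputs):
    # raw_inputs: (batch, sequence_length, features); padded positions are all-zero
    padding_mask = tf.reduce_any(tf.not_equal(raw_inputs, 0), axis=-1)  # (batch, seq_len), bool
    # add a query axis so it broadcasts against (batch, seq_len, seq_len) inside the attention
    return padding_mask[:, tf.newaxis, :]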
I hope someone more fluent in TensorFlow/Keras will be able to provide an example, or somebody will tell me that masking is useless given my architecture. The model is performing well; however, I hoped masking could help speed up the computation, and it just bugs me that I cannot get my head around it.
Upvotes: 11
Views: 8557
Reputation: 341
Maybe it is a little bit late, but for anyone who ends up on this post looking for a solution, this may help.
A typical scenario for using a Transformer is an NLP problem, where you have batches of sentences (let's assume they are already tokenized for simplicity). Consider the following example:
sentences = [['Lorem', 'ipsum', 'dolor', 'sit', 'amet'], ['Integer', 'tincidunt', 'in', 'arcu', 'nec', 'fringilla', 'suscipit']]
As you can see, we have two sentences of different length. In order to learn from them in a TensorFlow model, we can pad the shorter one with a special token, let's say '[PAD]', and feed them into a Transformer model, as you proposed. Hence:
sentences = tf.constant([['Lorem', 'ipsum', 'dolor', 'sit', 'amet', '[PAD]', '[PAD]'], ['Integer', 'tincidunt', 'in', 'arcu', 'nec', 'fringilla', 'suscipit']])
Also assuming that we already have a vocabulary of tokens extracted from some corpus, for example a vocabulary of 1000 tokens, we can define a StringLookup layer that converts our batch of sentences into their numerical projections given the vocabulary. We can also specify which token is used for masking:
lookup = tf.keras.layers.StringLookup(vocabulary=vocabulary, mask_token='[PAD]')
x = lookup(sentences)
# x is a tf.Tensor([[2, 150, 19, 997, 9, 0, 0], [72, 14, 1, 1, 960, 58, 87]], shape=(2, 7), dtype=int64)
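If you want to run the snippet above as it is, a toy vocabulary can stand in for the corpus-derived one assumed here (the resulting ids will of course differ from those shown in the comment):
# Toy stand-in for a vocabulary extracted from a corpus; '[PAD]' is reserved
# for index 0 and the out-of-vocabulary token for index 1, so known words start at 2.
vocabulary = ['Lorem', 'ipsum', 'dolor', 'sit', 'amet',
              'Integer', 'tincidunt', 'in', 'arcu', 'nec', 'fringilla', 'suscipit']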
where we can see that the '[PAD]' token maps to the value 0 in the vocabulary.
A typical next step is to feed this tensor into an Embedding layer, something like this:
embedding = tf.keras.layers.Embedding(input_dim=lookup.vocabulary_size(), output_dim=64, mask_zero=True)
The key here is the argument mask_zero. As per the documentation, this argument means:
Boolean, whether or not the input value 0 is a special "padding" value that should be masked out...
This allows the Embedding layer to generate a mask for the subsequent layers, indicating which positions should be attended to and which should not. This mask can be accessed via:
mask = embedding.compute_mask(x)
# mask is a tf.Tensor([[True, True, True, True, True, False, False], [True, True, True, True, True, True, True]], shape=(2, 7), dtype=bool)
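As a side note (not in the original snippet): since '[PAD]' maps to index 0, the same mask can also be computed directly from the ids:
mask = tf.not_equal(x, 0)  # same (2, 7) boolean tensor as embedding.compute_mask(x)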
The tensor of the embeddings is of the form:
y = embedding(x)
# y is a tf.Tensor of shape=(2, 7, 64), dtype=float32
In order to use the mask with the MultiHeadAttention layer, it must be reshaped to comply with the shape requirements, which per the documentation is [B, T, S], where B is the batch size (2 in the example), T is the query size (7 in our example), and S is the key size (again 7, since we are using self-attention). In a multi-head attention layer we must also take care of the number of heads H. The easiest way to create a compatible mask for this input is via broadcasting:
mask = mask[:, tf.newaxis, tf.newaxis, :]
# mask is a tf.Tensor of shape=(2, 1, 1, 7), dtype=bool -> [B, H, T, S]
Then we can finally feed the MultiHeadAttention layer as follows:
mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=64)
z = mha(y, y, attention_mask=mask)
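If you want to check that the padded positions are actually ignored, the attention scores can be returned as well (hypothetical variable names):
z, scores = mha(y, y, attention_mask=mask, return_attention_scores=True)
# scores has shape (2, 4, 7, 7) -> (batch, heads, queries, keys); for the first
# sentence, scores[0, ..., 5:] are ~0 because its last two positions are padding.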
So, in order to use your TransformerBlock layer with a mask, you should add a mask argument to the call method, as follows:
def call(self, inputs, training, mask=None):
    attn_output = self.att(inputs, inputs, attention_mask=mask)
    ...
And in the layer/model where you are calling the MultiHeadAttention layer, you must pass/propagate the mask that you generated with the Embedding layer.
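For illustration, here is a minimal sketch tying the pieces together, reusing y and the broadcast mask from above (ff_dim=128 is an assumed value; TransformerBlock is the question's class with the modified call):
block = TransformerBlock(embed_dim=64, num_heads=4, ff_dim=128)

# y: (2, 7, 64) embeddings; mask: (2, 1, 1, 7) boolean padding mask from above.
# Padded positions are ignored when computing the attention weights.
z = block(y, training=False, mask=mask)
# z is a tf.Tensor of shape=(2, 7, 64), dtype=float32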
Upvotes: 12