Stod

Reputation: 83

tensorflow.keras.layers.MultiHeadAttention warning that query layer is destroying mask

I am building a transformer model using tensorflow==2.16.1 and one of the layers is a tensorflow.keras.layers.MultiHeadAttention layer.

I implement the attention layer in the TransformerBlock below:

# Import TensorFlow and Keras for building and training neural network models
import tensorflow as tf
from tensorflow.keras.layers import (
    Dense,
    LayerNormalization,
    MultiHeadAttention,
    Dropout,
)

class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):

        super(TransformerBlock, self).__init__(**kwargs)

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        
        # Sub-layers are instantiated in build()
        self.att = None
        self.ffn = None

        self.layernorm1 = None
        self.layernorm2 = None

        self.dropout1 = None
        self.dropout2 = None

    def build(self, input_shape):

        self.att = MultiHeadAttention(num_heads=self.num_heads, key_dim=self.embed_dim)
        
        self.ffn = tf.keras.Sequential(
            [Dense(self.ff_dim, activation="relu"), Dense(self.embed_dim)]
        )

        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)

        self.dropout1 = Dropout(self.rate)
        self.dropout2 = Dropout(self.rate)

        super(TransformerBlock, self).build(input_shape)

    def call(self, inputs, training=None, padding_mask=None, causal_mask=True, qa=False):

        mask = None

        seq_len = tf.shape(inputs)[1]
        batch_size = tf.shape(inputs)[0]

        if padding_mask is not None:
            # Broadcast the (batch, seq_len) padding mask to (batch, seq_len, seq_len)
            padding_mask_reshaped = tf.cast(
                tf.reshape(padding_mask, (batch_size, 1, seq_len)), dtype=tf.float32
            )
            mask = tf.broadcast_to(
                padding_mask_reshaped, (batch_size, seq_len, seq_len)
            )


        attn_output = self.att(
            inputs, inputs, attention_mask=mask, use_causal_mask=causal_mask
        )

        attn_output = self.dropout1(attn_output, training=training)

        out1 = self.layernorm1(inputs + attn_output)

        ffn_output = self.ffn(out1)

        ffn_output = self.dropout2(ffn_output, training=training)

        out2 = self.layernorm2(out1 + ffn_output)

        return out2
        

Whenever I call this TransformerBlock I receive the following warning:

lib/python3.11/site-packages/keras/src/layers/layer.py:877: UserWarning: Layer 'value' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.
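
For context, here is a stripped-down sketch of the kind of call that produces the warning. The Embedding with mask_zero=True and the toy sizes are assumptions standing in for whatever attaches the Keras mask upstream in the real model:

import tensorflow as tf

# Toy inputs: 0 is the padding id, so the Embedding attaches a Keras mask to its output
embed = tf.keras.layers.Embedding(input_dim=1000, output_dim=64, mask_zero=True)
block = TransformerBlock(embed_dim=64, num_heads=2, ff_dim=128)

tokens = tf.constant([[5, 7, 9, 0, 0]])
padding_mask = tf.cast(tokens != 0, tf.float32)  # shape (batch, seq_len)

x = embed(tokens)  # the output tensor carries a Keras mask
out = block(x, training=False, padding_mask=padding_mask)  # triggers the UserWarning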

However, the masks do appear to change model performance despite this warning. For example, if I pass use_causal_mask=False the model performs unrealistically well, exactly as you would predict with no causal mask, which suggests the causal mask is being applied. I see the same behavior if I create the causal mask myself, merge it with the padding_mask, and pass the result through the attention_mask argument.
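
To make that last sentence concrete, the merged mask is built roughly like this (merged_attention_mask is just an illustrative helper name, not something from my actual code):

import tensorflow as tf

def merged_attention_mask(padding_mask):
    # Combine a (batch, seq_len) padding mask with a causal mask into a
    # (batch, seq_len, seq_len) attention mask: 1.0 = may attend, 0.0 = masked
    padding_mask = tf.cast(padding_mask, tf.float32)
    batch_size = tf.shape(padding_mask)[0]
    seq_len = tf.shape(padding_mask)[1]

    # Lower-triangular causal mask: query i may only attend to keys j <= i
    causal = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
    causal = tf.broadcast_to(causal[tf.newaxis, :, :], (batch_size, seq_len, seq_len))

    # Padding mask broadcast over the query dimension: (B, 1, S) -> (B, T, S)
    padding = tf.broadcast_to(
        padding_mask[:, tf.newaxis, :], (batch_size, seq_len, seq_len)
    )

    return causal * padding

# Inside call(), instead of relying on use_causal_mask:
#   mask = merged_attention_mask(padding_mask)
#   attn_output = self.att(inputs, inputs, attention_mask=mask)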

When I search the internet for this warning I find very little information about it. Does anyone know what the warning means and why I can't get rid of it?

Upvotes: 2

Views: 156

Answers (1)

Jenny

Reputation: 31

Please have a look at this similar issue for reference. There are some differences in how the MultiHeadAttention layer behaves inside custom layers between TensorFlow 2.14 and TensorFlow 2.16.

You can also refer to the Image captioning with visual attention example notebook for a reference MultiHeadAttention implementation inside a Transformer block, as sketched below.
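
For orientation, the causal self-attention block in that notebook looks roughly like this (reproduced from memory, so treat the exact names as approximate):

import tensorflow as tf

class CausalSelfAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        # Add() (rather than `+`) lets Keras propagate the mask through the residual connection
        self.add = tf.keras.layers.Add()
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        attn = self.mha(query=x, value=x, use_causal_mask=True)
        x = self.add([x, attn])
        return self.layernorm(x)

# Usage: layer = CausalSelfAttention(num_heads=2, key_dim=64); outputs = layer(inputs)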

Upvotes: 0
