Nathan Montanez

BERT fine-tuned transformer for chatbot not meeting expected performance

Question:

I've been working on a project that uses a transformer with a pre-trained BERT encoder for a Seq2Seq task. The goal is a chatbot with access to my computer, so that I can control my PC by voice. However, the model's performance is below what I expected: it either overfits or simply produces poor results. I'm looking for guidance on possible causes or on different approaches. One specific problem: for queries where part of the input should be used in the output (see the examples below), the model either repeats what it saw last in the dataset or produces some other random value it saw in the dataset.
e.g.
Query: Add bread to my shopping list
Response: Alright /uShoppingList'water'
Query: Lower the volume by sixteen
Response: Okay /uVolume'decrease four'

I'm worried that if I use data augmentation to grow the dataset, the model will neglect the other training samples, such as turning lights off or getting the temperature. Instead, each time __getitem__ is called by the dataloader, I randomly replace the shopping item (for example) with another item from a list of replacements. With this approach the dataset size stays the same, but I don't know whether the model gets enough exposure to learn these tasks. In practice it still fails to learn to take what was in the input and use it in the output; instead it produces random values, as in the examples above.
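One way to make the substitution approach teach copying is to draw the replacement once per sample and apply the same value to both the query and the response, so the argument in the target can only be predicted by reading it from the input. Below is a minimal sketch, assuming templates with str.format placeholders such as {item}; the class name, template strings, and replacement lists are hypothetical, not taken from the repository. Samples without placeholders pass through unchanged, so the dataset size and the share of lights/temperature samples stay the same.

```python
import random


class SlotSubstitutionDataset:
    """Hypothetical map-style dataset: it implements __len__ and __getitem__,
    so it can be passed directly to torch.utils.data.DataLoader."""

    def __init__(self, templates, slot_values, rng=None):
        # templates: list of (query, response) pairs that may contain
        #   str.format placeholders, e.g. "{item}" (assumed format)
        # slot_values: dict mapping placeholder name -> list of fillers
        # rng: defaults to the global random module, which DataLoader
        #   re-seeds in each worker process
        self.templates = templates
        self.slot_values = slot_values
        self.rng = rng if rng is not None else random

    def __len__(self):
        return len(self.templates)

    def __getitem__(self, idx):
        query, response = self.templates[idx]
        # Pick one value per slot and use it on BOTH sides, so the argument
        # in the response always equals the one in the query.
        fill = {name: self.rng.choice(values)
                for name, values in self.slot_values.items()}
        # str.format ignores unused keys, so templates without placeholders
        # come back unchanged.
        return query.format(**fill), response.format(**fill)


dataset = SlotSubstitutionDataset(
    templates=[
        ("Add {item} to my shopping list", "Alright /uShoppingList'{item}'"),
        ("Lower the volume by {number}", "Okay /uVolume'decrease {number}'"),
    ],
    slot_values={"item": ["bread", "milk", "eggs"],
                 "number": ["four", "ten", "sixteen"]},
)
query, response = dataset[0]
```

Literal braces in a template would need to be escaped as {{ and }}. One caveat, assuming the decoder has its own output vocabulary (it is sized by vocab_size in the posted Decoder): a filler that is not in that vocabulary cannot be generated at all, so the replacement lists need to be covered by the decoder's tokenizer.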

Details:

GitHub repository: https://github.com/NateTheGreat7117/Jarvis/tree/main

  1. I made my own custom dataset in which commands for controlling my computer are embedded in the chatbot's responses. It is in the GitHub repository in a file called conversation.txt. The query-response pairs are grouped into separate conversations in case I later want the transformer to look at previous queries when generating a response, as in a normal conversation. The other files in the repository are just replacement lists used for data augmentation.

Here is an example:
User: Lower the volume by sixteen
Chatbot: Okay /uVolume'decrease sixteen'
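Since the commands share a fixed pattern, the command name and its argument can be pulled out of a response and compared separately, which separates "wrong format" from "right format, wrong argument" (the failure described in the question). A minimal sketch; the regular expression is inferred only from the examples in this post (/u, a command name, then an argument in single quotes) and may not cover every command in conversation.txt. The function names are hypothetical.

```python
import re

# Pattern inferred only from the examples in this post:
#   /u<CommandName>'<argument>'   e.g.  /uVolume'decrease sixteen'
COMMAND_PATTERN = re.compile(r"/u(\w+)'([^']*)'")


def parse_commands(response):
    """Return every (command, argument) pair embedded in a response."""
    return COMMAND_PATTERN.findall(response)


def command_matches(predicted, reference):
    """True only if the command names AND their arguments agree."""
    return parse_commands(predicted) == parse_commands(reference)


def argument_accuracy(predictions, references):
    """Share of pairs whose embedded commands match exactly; pairs whose
    reference contains no command are skipped."""
    pairs = [(p, r) for p, r in zip(predictions, references) if parse_commands(r)]
    if not pairs:
        return 0.0
    return sum(command_matches(p, r) for p, r in pairs) / len(pairs)
```

Tracking this on the validation set alongside the loss shows directly whether a change improves argument copying, which a token-level loss can hide.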

  2. I'm using a Seq2Seq transformer architecture built in PyTorch with attention layers, residual connections, and positional embeddings. The encoder is a pre-trained BERT variant from HuggingFace that is fine-tuned separately from the decoder (with a different optimizer at a lower learning rate).
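Fine-tuning the encoder at a lower rate can also be done with a single optimizer and two parameter groups, which keeps one optimizer.step() and one learning-rate scheduler. A minimal sketch assuming AdamW; the helper name and both learning rates are placeholders, not tuned values. Parameters of the pre-trained module go into the low-rate group; everything else (including any layers added on top of BERT) gets the higher rate.

```python
import torch
from torch import nn


def build_optimizer(model, pretrained, encoder_lr=2e-5, other_lr=1e-4):
    """Hypothetical helper: one AdamW optimizer with a low learning rate for
    the pre-trained module and a higher one for all other parameters."""
    pretrained_ids = {id(p) for p in pretrained.parameters()}
    pretrained_params = [p for p in model.parameters() if id(p) in pretrained_ids]
    other_params = [p for p in model.parameters() if id(p) not in pretrained_ids]
    return torch.optim.AdamW([
        {"params": pretrained_params, "lr": encoder_lr},
        {"params": other_params, "lr": other_lr},
    ])


# Toy stand-in: the first layer plays the role of the BERT encoder.
toy_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
optimizer = build_optimizer(toy_model, pretrained=toy_model[0])
```

In the posted model, the pre-trained module would be the bert_encoder attribute of Encoder.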

Model: The model can be found in the /NeuralNetworks/Jarvis(bert encoder + Seq2Seq).ipynb file in my GitHub repository.

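import math

import torch
from torch import nn

# Assumed to match the notebook's own global device definition.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def positional_encoding(length, depth):
    # Sinusoidal encoding: the first half of each vector holds sines and the
    # second half cosines, at geometrically spaced frequencies.
    half_depth = depth // 2

    positions = torch.arange(length, dtype=torch.float32).unsqueeze(1)                # (length, 1)
    depths = torch.arange(half_depth, dtype=torch.float32).unsqueeze(0) / half_depth  # (1, half_depth)

    angle_rates = 1 / (10000 ** depths)       # (1, half_depth)
    angle_rads = positions * angle_rates      # (length, half_depth)

    pos_encoding = torch.cat(
        [torch.sin(angle_rads), torch.cos(angle_rads)],
        dim=-1)                               # (length, depth)

    return pos_encoding.to(device, dtype=torch.float32)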


class PositionalEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_length=2048):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        # The positional encoding adds each token's position to its vector,
        # information that attention alone does not see. Registering it as a
        # buffer lets model.to(device) move it together with the weights.
        self.register_buffer(
            "pos_encoding",
            positional_encoding(length=max_length, depth=d_model),
            persistent=False)

    def forward(self, x):
        length = x.size(1)
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and positional encoding.
        x = x * math.sqrt(self.d_model)
        x = x + self.pos_encoding[:length].unsqueeze(0)
        return x

# Attention
class BaseAttention(nn.Module):
    def __init__(self, d_model, **kwargs):
        super().__init__()
        self.num_heads = kwargs.get('num_heads')
        # kwargs go straight to nn.MultiheadAttention (embed_dim, num_heads, kdim, dropout).
        # It expects (seq_len, batch, embed_dim) inputs by default, hence the permutes below.
        self.mha = nn.MultiheadAttention(**kwargs)
        self.layernorm = nn.LayerNorm(d_model)


class CrossAttention(BaseAttention):
    def forward(self, x, context, context_padding_mask=None):
        # context_padding_mask: optional (batch, src_len) bool tensor, True at
        # encoder padding positions; without it the decoder can attend to [PAD].
        x_ = x.permute(1, 0, 2)
        context_ = context.permute(1, 0, 2)
        attn_output, attn_scores = self.mha(
            query=x_,
            key=context_,
            value=context_,
            key_padding_mask=context_padding_mask,
            need_weights=True)
        attn_output = attn_output.permute(1, 0, 2)
        # The head-averaged weights are already (batch, tgt_len, src_len),
        # so they are not permuted.

        # Cache the attention scores for plotting later.
        self.last_attn_scores = attn_scores

        x = x + attn_output
        x = self.layernorm(x)

        return x


class CausalSelfAttention(BaseAttention):
    def forward(self, x):
        x_ = x.permute(1, 0, 2)
        # (seq_len, seq_len) mask with -inf above the diagonal. A 2-D mask is
        # broadcast over the batch and heads, so it does not need expanding.
        attention_mask = nn.Transformer.generate_square_subsequent_mask(
            x_.shape[0]).to(x.device)

        attn_output = self.mha(
            query=x_,
            key=x_,
            value=x_,
            attn_mask=attention_mask,
            is_causal=True)[0]
        attn_output = attn_output.permute(1, 0, 2)
        x = x + attn_output
        x = self.layernorm(x)
        return x


# embed_dim must equal d_model for the residual connection and LayerNorm to work.
sample_csa = CausalSelfAttention(d_model=256, embed_dim=256,
                                 num_heads=2, kdim=256)

# Encoder
class FeedForward(nn.Module):
    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
            nn.Dropout(dropout_rate)
        ).to(device)
        self.layer_norm = nn.LayerNorm(d_model).to(device)

    def forward(self, x):
        x = x + self.seq(x)
        x = self.layer_norm(x)
        return x


class Encoder(nn.Module):
    def __init__(self, *, emb_size, d_model, dff,
                 dropout_rate=0.1):
        super(Encoder, self).__init__()

        # Pre-trained HuggingFace BERT model, defined elsewhere in the notebook.
        self.bert_encoder = bert_encoder

        self.ffn = FeedForward(emb_size, dff, dropout_rate)
        self.linear = nn.Linear(emb_size, d_model)

    def forward(self, x):
        input_tensor, input_type, input_mask = x
        # Pass by keyword: BertModel's positional order is
        # (input_ids, attention_mask, token_type_ids, ...), so a positional call
        # with (ids, token types, mask) swaps the mask and the token types.
        # Assumes a BERT-style model that accepts token_type_ids.
        x = self.bert_encoder(input_ids=input_tensor,
                              attention_mask=input_mask,
                              token_type_ids=input_type).last_hidden_state
        x = self.ffn(x)
        x = self.linear(x)
        return x

# Decoder
class DecoderLayer(nn.Module):
    def __init__(self,
                 *,
                 d_model,
                 num_heads,
                 dff,
                 dropout_rate=0.1):
        super(DecoderLayer, self).__init__()

        self.causal_self_attention = CausalSelfAttention(
            d_model=d_model,
            embed_dim=d_model,
            num_heads=num_heads,
            kdim=d_model,
            dropout=dropout_rate).to(device)

        self.cross_attention = CrossAttention(
            d_model=d_model,
            embed_dim=d_model,
            num_heads=num_heads,
            kdim=d_model,
            dropout=dropout_rate).to(device)

        self.ffn = FeedForward(d_model, dff, dropout_rate)

    def forward(self, x, context, context_padding_mask=None):
        x = self.causal_self_attention(x=x)
        x = self.cross_attention(x=x, context=context,
                                 context_padding_mask=context_padding_mask)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x

class Decoder(nn.Module):
    def __init__(self, *, emb_size, num_layers, d_model, num_heads, dff, vocab_size,
                 dropout_rate=0.1):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                                 d_model=d_model).to(device)
        self.dropout = nn.Dropout(dropout_rate)
        self.dec_layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, num_heads=num_heads,
                         dff=dff, dropout_rate=dropout_rate)
            for _ in range(num_layers)])

        self.last_attn_scores = None

        self.final_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x, context, context_padding_mask=None):
        # context_padding_mask: (batch, src_len) bool, True at encoder padding,
        # e.g. input_mask == 0 for a HuggingFace attention mask.
        x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)

        x = self.dropout(x)

        for layer in self.dec_layers:
            x = layer(x, context, context_padding_mask)

        self.last_attn_scores = self.dec_layers[-1].last_attn_scores
        logits = self.final_layer(x)  # raw scores, not log-probabilities

        return logits
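If the cross-attention receives no padding mask for the encoder output, decoder positions can also attend to BERT's padding tokens. nn.MultiheadAttention accepts a key_padding_mask for this, but with the opposite convention to a HuggingFace attention mask (True means "ignore this key"). A minimal standalone sketch with random tensors, using batch_first=True for brevity (the posted classes permute instead); it only checks that padded positions receive zero attention weight.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, num_heads = 16, 2
cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads,
                                   batch_first=True)

# Stand-in encoder output: batch of 2, source length 5.
context = torch.randn(2, 5, d_model)
# HuggingFace-style attention mask: 1 = real token, 0 = padding.
# The second sequence has 3 real tokens and 2 padding positions.
hf_attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                                  [1, 1, 1, 0, 0]])
# nn.MultiheadAttention expects the opposite convention: True = ignore.
key_padding_mask = hf_attention_mask == 0

decoder_states = torch.randn(2, 4, d_model)  # (batch, target length, d_model)
out, weights = cross_attn(decoder_states, context, context,
                          key_padding_mask=key_padding_mask,
                          need_weights=True)
# weights: (batch, target length, source length), averaged over heads.
# Padded keys of the second sequence get zero weight.
padded_weight = weights[1, :, 3:].abs().max().item()
```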

Performance Metrics: Most of the changes I make to the transformer just lead to overfitting. The best metrics I've gotten are a training loss of 0.2 with PyTorch's NLLLoss and a validation loss of around 0.55. With these, I consistently get good responses to simple queries such as asking for the time or weather. However, when I ask it to change my computer's volume, add items to my shopping list, or enter a Wikipedia query into a separate neural network, the output has the right format but contains random volumes/shopping items/Wikipedia queries (these are just examples; there are other cases like these).
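Two things might make these numbers look better than the slot behaviour. First, nn.NLLLoss expects log-probabilities, while Decoder.final_layer produces raw logits, so the training loop needs a log_softmax before NLLLoss (or nn.CrossEntropyLoss, which applies it internally). Second, the loss is averaged over every target token, most of which are template text such as "Okay /uVolume'decrease", so a wrong argument token barely moves the average; padding positions should also be excluded. The training loop isn't shown in the post, so this is a minimal sketch to check against, assuming pad id 0 (matching padding_idx=0 in PositionalEmbedding) and targets shaped (batch, seq_len).

```python
import torch
from torch import nn

PAD_ID = 0  # assumption: same id as padding_idx=0 in PositionalEmbedding


def sequence_loss(logits, targets):
    """logits: (batch, seq_len, vocab) raw scores from the decoder's final layer.
    targets: (batch, seq_len) token ids, padded with PAD_ID."""
    # CrossEntropyLoss = log_softmax + NLLLoss; ignore_index skips padding.
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
    # The class dimension must come second: (batch, vocab, seq_len).
    return loss_fn(logits.transpose(1, 2), targets)


torch.manual_seed(0)
logits = torch.randn(2, 5, 11)
targets = torch.tensor([[3, 4, 5, PAD_ID, PAD_ID],
                        [6, 7, 8, 9, PAD_ID]])
loss = sequence_loss(logits, targets)
```

For scale, if the reported values are correct per-token log-likelihoods, a loss of 0.2 corresponds to a perplexity of e^0.2 ≈ 1.22 and 0.55 to e^0.55 ≈ 1.73.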

Questions:

What could be potential reasons for the suboptimal performance?
Are there specific considerations or modifications I should make when using a pre-trained BERT encoder in a transformer for Seq2Seq tasks?
How can I diagnose and address performance issues in transformer models?
What other options do I have to make this work?

Answers (1)

LuCiFeR

I'm new to the transformer architecture, but it seems you may not have enough data to fine-tune BERT, which could explain the poor generalization. Alternatively, your hyperparameters may not be well chosen.
