sachinruk

Reputation: 9869

Getting NaNs for gradient

I am trying to create a search relevance model where I take the dot product between a query vector and the resulting document vectors. I add a positional bias term on top to account for the fact that position 1 is more likely to be clicked. The final (unnormalised) log-likelihood calculation is as follows:

        # Encode the query and documents, and get the learned positional bias
        query = self.query_model(query_input_ids, query_attention_mask)
        docs = self.doc_model(doc_input_ids, doc_attention_mask)
        positional_bias = self.position_model()

        # Alternating optimisation: detach whatever this optimizer is not updating
        if optimizer_idx is not None:
            if optimizer_idx == 0:
                docs = docs.detach()
                positional_bias = positional_bias.clone().detach()
            elif optimizer_idx == 1:
                query = query.detach()
                positional_bias = positional_bias.clone().detach()
            else:
                query = query.detach()
                docs = docs.detach()

        # Dot product between each document embedding and the query embedding
        similarity = (docs @ query.unsqueeze(-1)).squeeze()

        # Add the positional bias and mask out padding documents with -inf
        click_log_lik = (similarity + positional_bias)\
                .reshape(doc_mask.shape)\
                .masked_fill_((1 - doc_mask).bool(), float("-inf"))

The query and doc models are simply DistilBERT models with a projection layer on top of the CLS token. The models can be seen here: https://pastebin.com/g21g9MG3
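For reference, here is a rough sketch of the shape of these encoders. The real ones are in the pastebin above, so the class name, model name and output dimension below are just placeholders:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from transformers import AutoModel

    class Encoder(nn.Module):
        """DistilBERT backbone + linear projection over the CLS token, L2-normalised."""

        def __init__(self, model_name: str = "distilbert-base-uncased", out_dim: int = 128):
            super().__init__()
            self.backbone = AutoModel.from_pretrained(model_name)
            self.projection = nn.Linear(self.backbone.config.hidden_size, out_dim)

        def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
            hidden = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
            cls = hidden[:, 0]                        # embedding of the CLS token
            return F.normalize(self.projection(cls), dim=-1)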

When I inspect the first gradient-descent step, the gradients contain NaNs, but only for the query model, not the doc model. My hypothesis is that normalizing the return values of the doc and query models (return F.normalize(out, dim=-1)) is somehow interfering with the gradients.

Does anyone know (1) whether my hypothesis is true and, more importantly, (2) how I can fix the NaN gradients?
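For what it's worth, one way to narrow down where the NaNs first appear (a rough sketch; report_nans is just an illustrative helper, not part of the model code) is to check the encoder outputs and the per-parameter gradients right after backward:

    import torch

    # Hypothetical diagnostic helper: report which parameters received NaN gradients.
    def report_nans(model: torch.nn.Module, tag: str) -> None:
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print(f"NaN gradient in {tag}.{name}")

    # After loss.backward(); torch.autograd.set_detect_anomaly(True) can also point
    # at the op that first produced the NaN:
    # assert not torch.isnan(query).any(), "query embeddings contain NaNs"
    # assert not torch.isnan(docs).any(), "doc embeddings contain NaNs"
    # report_nans(self.query_model, "query_model")
    # report_nans(self.doc_model, "doc_model")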

Additional Info:

Update 1

The following changes made no difference to the NaNs:

Upvotes: 0

Views: 418

Answers (1)

sachinruk

Reputation: 9869

If it helps anyone who comes across this while using Transformers, this is what I did:

So in the end the bug was due to the fact that I was masking away NaNs. Since I had some documents with zero length, the transformer output for those documents was NaN. I was hoping that masked_fill would fix this problem, but it doesn't: it only replaces the forward values, and the NaN document embeddings still enter the backward pass (the gradient with respect to the query multiplies the doc embeddings, and 0 * NaN is still NaN), which is why only the query model's gradients blew up. The solution in my case was to only put non-zero-length sequences through the transformers, and then pad with zeros to fill the batch size.
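To see why masking after the fact doesn't remove the NaN gradients, here is a minimal, self-contained sketch (toy shapes; the NaN row stands in for a zero-length document's embedding). The forward pass looks fine after masked_fill, but the gradient with respect to the query multiplies the NaN document row, while the gradient with respect to the documents multiplies the finite query, which matches the NaNs showing up only in the query model:

    import torch

    query = torch.randn(4, requires_grad=True)
    docs = torch.randn(3, 4, requires_grad=True)
    with torch.no_grad():
        docs[1] = float("nan")                      # stand-in for a zero-length document

    similarity = docs @ query                       # entry 1 is NaN
    mask = torch.tensor([False, True, False])
    loss = similarity.masked_fill(mask, float("-inf")).logsumexp(dim=0)
    loss.backward()

    print(query.grad)                               # all NaN: 0 * NaN is still NaN
    print(docs.grad)                                # finite; the masked row's grad is zero

And a rough sketch of the workaround described above: encode only the non-empty sequences and scatter the results back into a zero-filled tensor so the batch keeps its shape (doc_model and out_dim here are placeholders):

    import torch

    def encode_docs(doc_model, input_ids, attention_mask, out_dim=128):
        lengths = attention_mask.sum(dim=-1)                      # tokens per document
        non_empty = lengths > 0
        out = torch.zeros(input_ids.size(0), out_dim,
                          dtype=torch.float, device=input_ids.device)
        if non_empty.any():
            out[non_empty] = doc_model(input_ids[non_empty],
                                       attention_mask[non_empty])
        return out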

Upvotes: 1
