Reputation: 9869
I am trying to create a search relevance model where I take the dot product between query vector and resulting documents. I add a positional bias term on top to take into account the fact that position 1 is more likely to be clicked on. The final (unnormalised) log likelihood calculation is as follows:
# Encode the query, the candidate documents, and the learned positional bias.
query = self.query_model(query_input_ids, query_attention_mask)
docs = self.doc_model(doc_input_ids, doc_attention_mask)
positional_bias = self.position_model()

# With multiple optimizers, detach everything except the sub-model being updated.
if optimizer_idx is not None:
    if optimizer_idx == 0:
        docs = docs.detach()
        positional_bias = positional_bias.clone().detach()
    elif optimizer_idx == 1:
        query = query.detach()
        positional_bias = positional_bias.clone().detach()
    else:
        query = query.detach()
        docs = docs.detach()

# Dot-product relevance plus positional bias; padded positions get -inf.
similarity = (docs @ query.unsqueeze(-1)).squeeze()
click_log_lik = (similarity + positional_bias) \
    .reshape(doc_mask.shape) \
    .masked_fill_((1 - doc_mask).bool(), float("-inf"))
The query and doc models are simply DistilBERT models with a projection layer on top of the CLS token. The models can be seen here: https://pastebin.com/g21g9MG3
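Roughly, each of the two models looks like this (simplified sketch only; the exact code is in the pastebin link above, so the class name, checkpoint, and projection size here are illustrative):
import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertModel

class Encoder(nn.Module):
    def __init__(self, out_dim=128):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.proj = nn.Linear(self.bert.config.dim, out_dim)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids,
                           attention_mask=attention_mask).last_hidden_state
        cls = hidden[:, 0]               # CLS token representation
        out = self.proj(cls)             # projection layer on top of CLS
        return F.normalize(out, dim=-1)  # unit-length output embedding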
When inspecting the first gradient descent step, the gradients contain NaNs, but only for the query model and not the doc model. My hypothesis is that normalizing the return values of the doc and query models (return F.normalize(out, dim=-1)) is somehow interfering with the gradients.
Does anyone know 1. whether my hypothesis is true and, more importantly, 2. how I can rectify the NaN gradients?
The masked_fill in the last line is because occasionally I have fewer than 10 data points for a query. The following changes made no difference to the NaNs: changing the masked_fill value from -inf to 1e5, and replacing F.normalize(out, dim=-1) with out / 100.
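In case it helps with debugging, this is the kind of check I would use to narrow down where the NaNs first appear (sketch only: torch.autograd's anomaly detection plus a per-parameter scan, with loss standing in for the reduced click_log_lik):
import torch

# Make autograd raise at the first backward op that produces NaN/Inf values.
torch.autograd.set_detect_anomaly(True)

loss.backward()

# Report which sub-model's parameters actually received NaN gradients.
for model_name, model in [("query_model", self.query_model),
                          ("doc_model", self.doc_model),
                          ("position_model", self.position_model)]:
    for param_name, param in model.named_parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            print(f"NaN grad: {model_name}.{param_name}")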
Upvotes: 0
Views: 418
Reputation: 9869
If it helps anyone who comes across this while using Transformers, this is what I did:
In the end the bug was due to the fact that I was masking away NaNs. Since I had some documents with zero length, the output of the transformer was NaN. I was hoping that masked_fill would fix this problem, but it doesn't. The solution in my case was to only put non-zero-length sequences through the transformer, and then pad the batch back to full size with zeros.
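A minimal sketch of that idea (the helper name, the emb_dim argument, and the shapes are illustrative; here the zero-length rows are filled with zeros in place rather than appended, but the effect is the same):
import torch

def encode_docs(doc_model, input_ids, attention_mask, emb_dim):
    # Zero-length sequences made the transformer output NaN, so only the rows
    # with at least one real token go through the model; the rest stay zero.
    nonempty = attention_mask.sum(dim=1) > 0

    out = torch.zeros(input_ids.size(0), emb_dim, device=input_ids.device)
    if nonempty.any():
        out[nonempty] = doc_model(input_ids[nonempty],
                                  attention_mask[nonempty])
    return out
The zero rows can then still be masked out with doc_mask as before, but nothing NaN ever enters the graph.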
Upvotes: 1