Alexander Soare

Reputation: 3257

How to do a masked mean in PyTorch?

This is the forward pass of a bidirectional RNN where I want to take the average pool of the output features. As you can see, I'm trying to exclude the time steps with a pad token from the calculation.

def forward(self, text):
    # text is shape (B, L)
    embed = self.embed(text)
    rnn_out, _ = self.rnn(embed)  # (B, L, 2*H)
    # Calculate average ignoring the pad token
    with torch.no_grad():
        rnn_out[text == self.pad_token] *= 0
        denom = torch.sum(text != self.pad_token, -1, keepdim=True)
    feat = torch.sum(rnn_out, dim=1) / denom
    feat = self.dropout(feat)
    return feat

Backpropagation raises an exception because of the line rnn_out[text == self.pad_token] *= 0. Here's what it looks like:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [32, 21, 128]], which is output 0 of CudnnRnnBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
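For what it's worth, the error doesn't seem specific to the RNN; a tiny made-up example with the same in-place-assignment-under-no_grad pattern hits the same RuntimeError:

import torch

x = torch.randn(5, requires_grad=True)
y = torch.tanh(x)     # tanh's backward pass needs its output y
with torch.no_grad():
    y[y < 0] *= 0     # the in-place edit still bumps y's version counter
y.sum().backward()    # raises the same "modified by an inplace operation" error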

What's the correct way to do this?

Note: I know there are one or two workarounds I could use here, but I want to know if there's a cleaner way not involving those.

Upvotes: 3

Views: 6342

Answers (1)

flawr

Reputation: 11628

You're modifying the tensor in place (using *=) inside a context where building the computational graph is disabled; this will wreak havoc on the computation of the gradient. Instead I'd suggest the following:

mask = text != self.pad_token                                    # (B, L), True for non-pad steps
denom = torch.sum(mask, -1, keepdim=True)                        # (B, 1), number of non-pad steps
feat = torch.sum(rnn_out * mask.unsqueeze(-1), dim=1) / denom    # (B, 2*H) masked mean

You may have to tweak this snippet a little; I couldn't test it as you haven't provided a complete example, but it hopefully shows the technique you can use.
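Put back into the forward pass from your question, the whole thing could look roughly like this (an untested sketch that reuses your attribute names self.embed, self.rnn, self.dropout and self.pad_token):

def forward(self, text):
    # text is shape (B, L)
    embed = self.embed(text)
    rnn_out, _ = self.rnn(embed)  # (B, L, 2*H)

    # Build a mask and multiply instead of zeroing rnn_out in place,
    # so the tensors autograd saved for backward stay untouched
    mask = (text != self.pad_token).unsqueeze(-1)  # (B, L, 1)
    denom = mask.sum(dim=1)                        # (B, 1), non-pad steps per sequence
    feat = (rnn_out * mask).sum(dim=1) / denom     # (B, 2*H) masked mean

    feat = self.dropout(feat)
    return feat

If a sequence could consist entirely of padding, denom.clamp(min=1) guards against a division by zero.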

Upvotes: 5
