Blade

Reputation: 1110

PyTorch equivalent of a Tensorflow loss function

I was trying to reimplement TensorFlow code using the PyTorch framework. Below I have included the TF sample code and my PyTorch interpretation, for a target of size (Batch, 9, 9, 4) and a network output of the same size.

TensorFlow implementation:

loss = tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=output)
loss = tf.matrix_band_part(loss, 0, -1) - tf.matrix_band_part(loss, 0, 0)

PyTorch implementation:

output = torch.tensor(output, requires_grad=True).view(-1, 4)
target = torch.tensor(target).view(-1, 4).argmax(1)

loss = torch.nn.CrossEntropyLoss(reduction='none')
my_loss = loss(output, target).view(-1,9,9)

For the PyTorch implementation, I'm not sure how to implement tf.matrix_band_part. I was thinking about defining a mask, but I was not sure whether that would hurt backpropagation. I am aware of torch.triu, but this function does not work for tensors with more than 2 dimensions.

Upvotes: 1

Views: 703

Answers (1)

Grigory Feldman

Reputation: 413

Since (at least) version 1.2.0, torch.triu works well with batched tensors (as per the docs): it is applied to the last two dimensions, leaving leading batch dimensions intact.
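A quick sketch (with a made-up batch size of 2) showing batched torch.triu:

```python
import torch

# torch.triu operates on the last two dimensions, so a (Batch, 9, 9)
# tensor is handled per batch element with no reshaping needed.
A = torch.arange(2 * 9 * 9, dtype=torch.float32).view(2, 9, 9)

upper = torch.triu(A)                      # diagonal and above, per matrix
strict_upper = torch.triu(A, diagonal=1)   # strictly above the diagonal
```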

You can get diagonal elements via einsum: torch.einsum('...ii->...i', A).
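For example (this is equivalent to torch.diagonal over the last two dimensions):

```python
import torch

A = torch.randn(2, 9, 9)

# '...ii->...i' picks out the diagonal of each trailing 9x9 matrix.
diag = torch.einsum('...ii->...i', A)      # shape (2, 9)
```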

Applying a mask doesn't hurt backprop. You can think of it as a projection (which obviously works well with backprop).
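Putting it together, here is a sketch of the full translation with random data and a made-up batch size. Note that tf.matrix_band_part(loss, 0, -1) - tf.matrix_band_part(loss, 0, 0) keeps the upper triangle minus the diagonal, i.e. the strictly upper-triangular part, so torch.triu(..., diagonal=1) reproduces it in one call, and gradients flow through the mask:

```python
import torch

batch = 2
output = torch.randn(batch, 9, 9, 4, requires_grad=True)   # logits
target_idx = torch.randint(0, 4, (batch, 9, 9))            # class indices

# Per-cell cross entropy, reshaped back to (Batch, 9, 9)
loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
per_cell = loss_fn(output.view(-1, 4), target_idx.view(-1)).view(batch, 9, 9)

# tf.matrix_band_part(loss, 0, -1) - tf.matrix_band_part(loss, 0, 0)
# == strictly upper-triangular part of each 9x9 loss matrix
masked = torch.triu(per_cell, diagonal=1)

masked.sum().backward()   # the mask does not block backprop
```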

Upvotes: 1
