Mohit Lamba

Reputation: 1403

Correct way to compute VGG features for Perceptual loss

When computing the VGG perceptual loss, I think it should be fine to wrap the computation of the VGG features for the ground-truth (GT) image inside torch.no_grad(), although I have not seen this done anywhere.

In other words, I think the following should work:

with torch.no_grad():
    gt_vgg_features = self.vgg_features(gt)

nw_op_vgg_features = self.vgg_features(nw_op)

loss = torch.nn.functional.l1_loss(nw_op_vgg_features, gt_vgg_features)  # L1 loss between the feature maps

Or should one use:

gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)

In both approaches, requires_grad is set to False for the VGG parameters and VGG is put in eval() mode.
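The setup described above involves two separate steps, and it can help to see that they do different things. The sketch below uses a small hypothetical nn.Sequential as a stand-in for the VGG feature layers (random weights, not real VGG). With torchvision one would instead load the pretrained layers, e.g. something like torchvision.models.vgg16(weights="DEFAULT").features, but that is an assumption about your torchvision version and is not needed here.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the VGG feature layers (random weights, not VGG).
vgg_features = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)

# eval() only switches layer behaviour (dropout, batch norm); it does not freeze weights.
vgg_features.eval()
still_trainable = all(p.requires_grad for p in vgg_features.parameters())

# Freezing is a separate step: tell autograd not to compute parameter gradients.
for p in vgg_features.parameters():
    p.requires_grad_(False)
frozen = not any(p.requires_grad for p in vgg_features.parameters())
```

After eval() alone, still_trainable is True; only the requires_grad_(False) loop actually freezes the weights.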

I expect the first approach to save a lot of GPU memory, and I think it should be numerically equivalent to the second, since no backpropagation through the GT image is required. However, most implementations I have seen use the second approach for computing the VGG perceptual loss.
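The equivalence claim can be checked directly. The sketch below (again with a hypothetical small stand-in for VGG, and a random tensor standing in for the network output nw_op) computes the loss and the gradient w.r.t. nw_op both ways and returns whether the GT features are tracked by autograd. One thing it also shows: autograd only records an operation when at least one of its inputs requires grad, so with frozen parameters and a GT tensor that does not require grad, gt_feat is untracked in both variants. In that specific setup the memory difference may be smaller than expected; torch.no_grad() still guarantees it regardless of how gt was produced.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for a frozen VGG feature extractor.
vgg_features = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 16, 3, padding=1)
).eval()
for p in vgg_features.parameters():
    p.requires_grad_(False)

gt = torch.rand(2, 3, 32, 32)                         # ground truth: no gradient needed
nw_op = torch.rand(2, 3, 32, 32, requires_grad=True)  # stands in for the network output


def loss_and_grad(use_no_grad: bool):
    """Return (loss, d loss / d nw_op, whether the GT features are tracked)."""
    nw_op.grad = None  # reset so each run starts from a clean gradient
    if use_no_grad:
        with torch.no_grad():
            gt_feat = vgg_features(gt)
    else:
        gt_feat = vgg_features(gt)
    out_feat = vgg_features(nw_op)
    loss = F.l1_loss(out_feat, gt_feat)
    loss.backward()
    return loss.detach(), nw_op.grad.clone(), gt_feat.requires_grad


loss_a, grad_a, gt_tracked_a = loss_and_grad(use_no_grad=True)   # first approach
loss_b, grad_b, gt_tracked_b = loss_and_grad(use_no_grad=False)  # second approach
```

Comparing loss_a with loss_b and grad_a with grad_b (e.g. with torch.allclose) confirms that both approaches produce the same loss and the same gradient for the network output.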

So which option should we use when implementing the VGG perceptual loss in PyTorch?

Upvotes: 4

Views: 1029

Answers (1)

Shai

Reputation: 114926

The first way:

with torch.no_grad():
    gt_vgg_features = self.vgg_features(gt)

nw_op_vgg_features = self.vgg_features(nw_op)

Although VGG is in eval mode and its parameters are kept fixed, you still need to propagate gradients through it, from the loss on the features back to the output nw_op. However, there is no reason to compute these gradients w.r.t. gt.
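Following this answer, the whole loss can be wrapped in a small module. This is a minimal sketch, not a reference implementation: the class name PerceptualLoss is hypothetical, and any frozen feature extractor can be passed in (for real use, some slice of a pretrained torchvision VGG; here a tiny stand-in). The train() override is a design choice so that calling .train() on a parent model does not silently switch the extractor out of eval mode.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerceptualLoss(nn.Module):
    """Hypothetical sketch: L1 distance between features of a frozen extractor."""

    def __init__(self, feature_extractor: nn.Module):
        super().__init__()
        self.vgg_features = feature_extractor.eval()
        for p in self.vgg_features.parameters():
            p.requires_grad_(False)  # weights stay fixed

    def train(self, mode: bool = True):
        # Keep the extractor in eval mode even when the parent calls .train().
        super().train(mode)
        self.vgg_features.eval()
        return self

    def forward(self, nw_op: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            gt_vgg_features = self.vgg_features(gt)  # target only: no graph needed
        nw_op_vgg_features = self.vgg_features(nw_op)  # gradients flow back to nw_op
        return F.l1_loss(nw_op_vgg_features, gt_vgg_features)


# Usage with a tiny stand-in extractor (random weights, not VGG).
torch.manual_seed(0)
criterion = PerceptualLoss(nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()))
criterion.train()
nw_op = torch.rand(1, 3, 16, 16, requires_grad=True)
gt = torch.rand(1, 3, 16, 16)
loss = criterion(nw_op, gt)
loss.backward()
```

After backward(), nw_op.grad is populated while the extractor's parameters have no gradients, which is exactly the split the answer describes.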

Upvotes: 2
