sealpuppy

Reputation: 633

Use Pytorch SSIM loss function in my model

I am trying out the SSIM loss implemented by this repo for image restoration.

Following the original sample code on the author's GitHub, I tried:

model.train()
for epo in range(epoch):
    for i, data in enumerate(trainloader, 0):
        inputs = data
        inputs = Variable(inputs)
        optimizer.zero_grad()
        inputs = inputs.view(bs, 1, 128, 128)
        top = model.upward(inputs)
        outputs = model.downward(top, shortcut = True)
        outputs = outputs.view(bs, 1, 128, 128)

        if i % 20 == 0:
            out = outputs[0].view(128, 128).detach().numpy() * 255
            cv2.imwrite("/home/tk/Documents/recover/SSIM/" + str(epo) + "_" + str(i) + "_re.png", out)

        loss = - criterion(inputs, outputs)
        ssim_value = - loss.data.item()
        print (ssim_value)
        loss.backward()
        optimizer.step()

However, the results didn't come out as I expected: after the first 10 epochs, the output images I saved were all black.

loss = - criterion(inputs, outputs) is what the author proposes, but classical PyTorch training code uses loss = criterion(y_pred, target), so presumably it should be loss = criterion(inputs, outputs) here.

I tried loss = criterion(inputs, outputs) as well, but the results were still the same.

Can anyone share some thoughts about how to properly utilize SSIM loss? Thanks.

Upvotes: 4

Views: 30583

Answers (2)

Donshel

Reputation: 624

The usual way to transform a similarity (higher is better) into a loss is to compute 1 - similarity(x, y).

To create this loss, you can define a new "function".

def ssim_loss(x, y):
    return 1. - ssim(x, y)
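
For example, a quick sanity check (assuming ssim here is the functional form from the repo in the question, importable as pytorch_ssim.ssim) shows the loss is near 0 for identical images and grows as the images diverge:

import torch
from pytorch_ssim import ssim  # functional SSIM from the repo in the question

def ssim_loss(x, y):
    return 1. - ssim(x, y)

x = torch.rand(4, 1, 128, 128)                    # batch of grayscale images in [0, 1]
print(ssim_loss(x, x).item())                     # ~0.0: identical images
print(ssim_loss(x, torch.rand_like(x)).item())    # close to 1.0: unrelated noise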

Alternatively, if the similarity is a class (nn.Module), you can overload it to create a new one.

class SSIMLoss(SSIM):
    def forward(self, x, y):
        return 1. - super().forward(x, y)

Also, there are better implementations of SSIM than the one in this repo. For example, the one in the piqa Python package is faster. The package can be installed with

pip install piqa

For your problem

from piqa import SSIM

class SSIMLoss(SSIM):
    def forward(self, x, y):
        return 1. - super().forward(x, y)

criterion = SSIMLoss() # .cuda() if you need GPU support

...
loss = criterion(x, y)
...

should work well.
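
A minimal self-contained sketch of that, assuming piqa's SSIM expects 4-D float tensors with values in [0, 1] (3-channel by default; check the piqa documentation if your images are single-channel):

import torch
from piqa import SSIM

class SSIMLoss(SSIM):
    def forward(self, x, y):
        return 1. - super().forward(x, y)

criterion = SSIMLoss()  # .cuda() if you need GPU support

x = torch.rand(4, 3, 128, 128, requires_grad=True)  # e.g. the network output
y = torch.rand(4, 3, 128, 128)                      # target images in [0, 1]

loss = criterion(x, y)
loss.backward()     # gradients flow back through the SSIM computation
print(loss.item())  # near 1.0 for unrelated noise, near 0.0 for identical images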

Upvotes: 4

Kinal

Reputation: 31

The author is trying to maximize the SSIM value. In PyTorch, the loss function and optimizer work together to reduce the loss, but SSIM is a quality measure, so higher is better. Hence the author uses
loss = - criterion(inputs, outputs)

You can instead try using
loss = 1 - criterion(inputs, outputs)
as described in this paper.
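
Applied to the training loop in the question, the only change is the loss line (a sketch, assuming criterion is the SSIM module from the repo, i.e. pytorch_ssim.SSIM(); model, trainloader, bs and optimizer are as in the question):

import pytorch_ssim

criterion = pytorch_ssim.SSIM()

for i, data in enumerate(trainloader, 0):
    inputs = data.view(bs, 1, 128, 128)
    optimizer.zero_grad()
    top = model.upward(inputs)
    outputs = model.downward(top, shortcut=True).view(bs, 1, 128, 128)

    loss = 1 - criterion(inputs, outputs)   # minimize 1 - SSIM instead of -SSIM
    ssim_value = 1 - loss.item()            # recover the SSIM value for logging
    print(ssim_value)
    loss.backward()
    optimizer.step()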


Modified code (max_ssim.py) for testing the above, using this repo:

import pytorch_ssim
import torch
from torch.autograd import Variable
from torch import optim
import cv2
import numpy as np

npImg1 = cv2.imread("einstein.png")

img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0
img2 = torch.rand(img1.size())

if torch.cuda.is_available():
    img1 = img1.cuda()
    img2 = img2.cuda()


img1 = Variable(img1, requires_grad=False)
img2 = Variable(img2, requires_grad=True)

print(img1.shape)
print(img2.shape)
# Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True)
ssim_value = 1 - pytorch_ssim.ssim(img1, img2).item()  # initial loss: 1 - SSIM
print("Initial ssim loss:", ssim_value)

# Module: pytorch_ssim.SSIM(window_size = 11, size_average = True)
ssim_loss = pytorch_ssim.SSIM()

optimizer = optim.Adam([img2], lr=0.01)

while ssim_value > 0.05:
    optimizer.zero_grad()
    ssim_out = 1-ssim_loss(img1, img2)
    ssim_value = ssim_out.item()
    print(ssim_value)
    ssim_out.backward()
    optimizer.step()
    # show the current estimate; press any key to advance to the next step
    cv2.imshow('op', np.transpose(img2.cpu().detach().numpy()[0], (1, 2, 0)))
    cv2.waitKey()

Upvotes: 3
