yun_91D

Reputation: 1

Perceptual loss function in a diffusion model

I am trying to compute a perceptual loss between the generated image and the ground-truth image in a diffusion model (I use it for image-to-image translation; the images are grayscale). This is the code of the loss function:

from .custom_loss import common

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class VGGPerceptualLoss(nn.Module):
    def __init__(self, conv_index, rgb_range=1):
        super(VGGPerceptualLoss, self).__init__()
        # VGG19 feature extractor from torchvision, truncated according to conv_index
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        if conv_index.find('22') >= 0:
            # features up to conv2_2
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index.find('54') >= 0:
            # features up to conv5_4
            self.vgg = nn.Sequential(*modules[:35])

        # MeanShift normalizes the input with the ImageNet mean/std via a 1x1
        # convolution, so it expects a 3-channel (RGB) input
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x
            
        vgg_sr = _forward(sr[0])
        with torch.no_grad():
            vgg_hr = _forward(hr[0].detach())

        loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss

And this is the relevant part of the diffusion model's training step:

    def forward_backward(self, batch, cond):
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }

            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses_segmentation,
                self.ddp_model,
                self.classifier,
                self.prior,
                self.posterior,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses1 = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses1 = compute_losses()

            losses = losses1[0]
            sample = losses1[1]

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(t, losses["loss"].detach())
            conv_index='22'
            perceptual_loss_fn = VGGPerceptualLoss(conv_index=conv_index) 
            perceptual_loss = perceptual_loss_fn(sample, batch)
            loss = (losses["loss"] * weights).mean()+(perceptual_loss * self.perceptual_loss_weight)
            # tensor_losses = torch.cat((losses["loss"],perceptual_loss))
            # loss = (tensor_losses* weights).mean()
            # print(f'weights: {weights}')
            # print(f'loss: {losses["loss"]}')
            # loss = (losses["loss"] * weights).mean()
            lossseg = (losses["mse"] * weights).mean().detach()
            losscls = (losses["vb"] * weights).mean().detach()
            lossrec = loss * 0

            log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()})
            self.mp_trainer.backward(loss)
            return lossseg.detach(), losscls.detach(), lossrec.detach(), sample

This is the error. How can I fix it?

/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=VGG19_Weights.IMAGENET1K_V1. You can also use weights=VGG19_Weights.DEFAULT to get the most up-to-date weights.
  warnings.warn(msg)
Traceback (most recent call last):
  File "/content/drive/MyDrive/seg/seg/scripts/segmentation_train.py", line 147, in <module>
    main()
  File "/content/drive/MyDrive/seg/seg/scripts/segmentation_train.py", line 118, in main
    ).run_loop()
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/train_util.py", line 195, in run_loop
    self.run_step(batch, cond)
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/train_util.py", line 220, in run_step
    lossseg, losscls, lossrec, sample = self.forward_backward(batch, cond)
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/train_util.py", line 267, in forward_backward
    perceptual_loss = perceptual_loss_fn(sample, batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/vgg.py", line 30, in forward
    vgg_sr = _forward(sr[0])
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/vgg.py", line 26, in _forward
    x = self.sub_mean(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [3, 3, 1, 1], expected input[1, 1, 128, 128] to have 3 channels, but got 1 channels instead

Upvotes: 0

Views: 370

Answers (1)

Chih-Hao Liu

Reputation: 466

In your VGGPerceptualLoss,

vgg_mean = (0.485, 0.456, 0.406)
vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)

this normalization (and the VGG features that follow it) only accepts 3-channel RGB input.

However, in your error message,

RuntimeError: Given groups=1, weight of size [3, 3, 1, 1], expected input[1, 1, 128, 128] to have 3 channels, but got 1 channels instead

you are passing a grayscale image as a tensor of size (1, 1, 128, 128), i.e. with a single channel.

Therefore, you can use the following code to open your grayscale image data:

from PIL import Image
img = Image.open("[your training image]").convert("RGB")

This converts your grayscale image to RGB format with 3 channels, which matches what the VGG network expects.
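Alternatively, if the tensors that reach VGGPerceptualLoss are still single-channel (the traceback shows shape (1, 1, 128, 128)), you can repeat the channel dimension right before the VGG forward pass. A minimal sketch, assuming 4-D inputs of shape (N, 1, H, W); the helper name to_three_channels is only illustrative:

import torch

def to_three_channels(x: torch.Tensor) -> torch.Tensor:
    # Repeat the single grayscale channel three times so the tensor
    # matches the (N, 3, H, W) layout that MeanShift/VGG expect.
    if x.dim() == 4 and x.size(1) == 1:
        x = x.repeat(1, 3, 1, 1)
    return x

# Illustrative usage inside VGGPerceptualLoss._forward:
# x = to_three_channels(x)
# x = self.sub_mean(x)
# x = self.vgg(x)

Repeating the channel keeps the perceptual loss unchanged for grayscale content, since all three VGG input channels then see the same values.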

Upvotes: 1
