Reputation: 1
I am trying to compute a perceptual loss between the generated image and the ground-truth image in a diffusion model (I use it for image-to-image translation; the images are grayscale). This is the code of the loss function:
from .custom_loss import common
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class VGGPerceptualLoss(nn.Module):
    def __init__(self, conv_index, rgb_range=1):
        super(VGGPerceptualLoss, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        if conv_index.find('22') >= 0:
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index.find('54') >= 0:
            self.vgg = nn.Sequential(*modules[:35])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x

        vgg_sr = _forward(sr[0])
        with torch.no_grad():
            vgg_hr = _forward(hr[0].detach())
        loss = F.mse_loss(vgg_sr, vgg_hr)
        return loss
And this is the relevant part of the diffusion model's training code:
def forward_backward(self, batch, cond):
    self.mp_trainer.zero_grad()
    for i in range(0, batch.shape[0], self.microbatch):
        micro = batch[i : i + self.microbatch].to(dist_util.dev())
        micro_cond = {
            k: v[i : i + self.microbatch].to(dist_util.dev())
            for k, v in cond.items()
        }
        last_batch = (i + self.microbatch) >= batch.shape[0]
        t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

        compute_losses = functools.partial(
            self.diffusion.training_losses_segmentation,
            self.ddp_model,
            self.classifier,
            self.prior,
            self.posterior,
            micro,
            t,
            model_kwargs=micro_cond,
        )

        if last_batch or not self.use_ddp:
            losses1 = compute_losses()
        else:
            with self.ddp_model.no_sync():
                losses1 = compute_losses()

        if isinstance(self.schedule_sampler, LossAwareSampler):
            self.schedule_sampler.update_with_local_losses(t, losses["loss"].detach())

        losses = losses1[0]
        sample = losses1[1]

        conv_index = '22'
        perceptual_loss_fn = VGGPerceptualLoss(conv_index=conv_index)
        perceptual_loss = perceptual_loss_fn(sample, batch)
        loss = (losses["loss"] * weights).mean() + (perceptual_loss * self.perceptual_loss_weight)
        # tensor_losses = torch.cat((losses["loss"], perceptual_loss))
        # loss = (tensor_losses * weights).mean()
        # print(f'weights: {weights}')
        # print(f'loss: {losses["loss"]}')
        # loss = (losses["loss"] * weights).mean()
        lossseg = (losses["mse"] * weights).mean().detach()
        losscls = (losses["vb"] * weights).mean().detach()
        lossrec = loss * 0
        log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()})
        self.mp_trainer.backward(loss)
        return lossseg.detach(), losscls.detach(), lossrec.detach(), sample
This is the error. How can I fix it?
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Traceback (most recent call last):
  File "/content/drive/MyDrive/seg/seg/scripts/segmentation_train.py", line 147, in <module>
    main()
  File "/content/drive/MyDrive/seg/seg/scripts/segmentation_train.py", line 118, in main
    ).run_loop()
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/train_util.py", line 195, in run_loop
    self.run_step(batch, cond)
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/train_util.py", line 220, in run_step
    lossseg, losscls, lossrec, sample = self.forward_backward(batch, cond)
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/train_util.py", line 267, in forward_backward
    perceptual_loss = perceptual_loss_fn(sample, batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/vgg.py", line 30, in forward
    vgg_sr = _forward(sr[0])
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/vgg.py", line 26, in _forward
    x = self.sub_mean(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [3, 3, 1, 1], expected input[1, 1, 128, 128] to have 3 channels, but got 1 channels instead
Upvotes: 0
Views: 370
Reputation: 466
In your VGGPerceptualLoss,

vgg_mean = (0.485, 0.456, 0.406)
vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)

the sub_mean layer (and the VGG network after it) only accepts 3-channel RGB input.
However, your error message,

RuntimeError: Given groups=1, weight of size [3, 3, 1, 1], expected input[1, 1, 128, 128] to have 3 channels, but got 1 channels instead

shows that you are feeding it a grayscale image with tensor size (1, 1, 128, 128).
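I can't see your custom_loss.common module, but the weight of size [3, 3, 1, 1] in the traceback is exactly what you would get if MeanShift is implemented as a fixed 1x1 convolution with 3 input and 3 output channels, as in EDSR-style code. A minimal sketch of such a layer, assuming that convention:

import torch
import torch.nn as nn

class MeanShift(nn.Conv2d):
    # Normalizes a 3-channel image, roughly (x - mean) / std, as a frozen 1x1 conv.
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super().__init__(3, 3, kernel_size=1)  # 3 in / 3 out channels -> weight shape [3, 3, 1, 1]
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False

Such a convolution can only consume 3-channel input, which is why your 1-channel tensor fails.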
Therefore, you can convert your grayscale image data to RGB when you load it:

from PIL import Image
img = Image.open("[your training image]").convert("RGB")

This replicates the single channel so your grayscale image becomes a 3-channel RGB image.
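Note that the crash actually happens on tensors inside your training loop (sample and batch), not at file-loading time, so if changing the data pipeline is inconvenient, an equivalent fix is to repeat the single channel to three channels right before the VGG features. This is only a sketch: to_three_channels is a helper name I made up, and depending on whether sample and batch are plain tensors or tuples here, you may need to apply it to the element you actually pass into _forward.

import torch

def to_three_channels(x: torch.Tensor) -> torch.Tensor:
    # Repeat a (N, 1, H, W) grayscale tensor along the channel dim to (N, 3, H, W).
    if x.dim() == 4 and x.size(1) == 1:
        x = x.repeat(1, 3, 1, 1)
    return x

# e.g. at the top of _forward inside VGGPerceptualLoss.forward:
#     x = to_three_channels(x)
#     x = self.sub_mean(x)
#     x = self.vgg(x)

Since the repeated channels are identical, this is numerically the same as converting the image to RGB with PIL.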
Upvotes: 1