dorito

Reputation: 27

How to stop "holes" from showing up in CycleGAN?

So I've been trying to build a CycleGAN model in PyTorch from scratch, training it on the original vangogh2photo dataset provided by Berkeley. Admittedly, it's for fun and not for any research, but I still hate it when it doesn't work out. This is the architecture of the model I've made:

import torch
import torch.nn as nn
import torch.nn.functional as F
import random

'''An image buffer that stores up to 50 previously generated images.
This follows the paper that introduced CycleGAN, "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks".'''
class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0, "max_size must be positive."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                # With probability 0.5, return a randomly chosen older image
                # and replace it in the buffer with the newly generated one.
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    # Otherwise, return the newly generated image as-is.
                    to_return.append(element)
        return torch.cat(to_return)  # Variable is deprecated; a plain tensor works in modern PyTorch.
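
# Usage sketch (assumed wiring; names like G_AB, real_A, fake_B_buffer are my own):
#   fake_B = G_AB(real_A)
#   fake_B_mixed = fake_B_buffer.push_and_pop(fake_B)
#   loss_D = criterion_GAN(D_B(fake_B_mixed.detach()), fake)
# Training the discriminator on a mix of current and past fakes reduces oscillation.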
    

'''Linear learning rate scheduler.'''

class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        if decay_start_epoch >= n_epochs:
            raise ValueError("Decay must start before training ends. Set decay_start_epoch to a value less than {}.".format(n_epochs))
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
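
# Usage sketch (assumed): the .step method is meant to be passed as lr_lambda
# to PyTorch's built-in scheduler, e.g.
#   scheduler = torch.optim.lr_scheduler.LambdaLR(
#       optimizer, lr_lambda=LambdaLR(n_epochs=20, offset=0, decay_start_epoch=19).step)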
    
'''A single residual block. InstanceNorm2d is known to produce blob artefacts; consider switching to modulated convolutions later.
For now, augmentation and a low epoch count are used to keep the generator from producing artefacts.'''

class ResNetBlock(nn.Module):
    def __init__(self, channels):
        super(ResNetBlock, self).__init__()

        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.conv_block(x)

class GeneratorResNet(nn.Module):
    def __init__(self, input_channels, output_channels, num_resnet_blocks=9):
        super(GeneratorResNet, self).__init__()

        # Initial convolutional layer
        self.initial_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Downsampling layers
        self.downsampling_1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.downsampling_2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Residual layers
        self.residual_layers = nn.Sequential(
            *[ResNetBlock(256) for _ in range(num_resnet_blocks)]
        )

        # Upsampling layers
        self.upsampling_1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.upsampling_2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Final convolutional layer
        self.final_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_channels, kernel_size=7, padding=0, bias=True),
            nn.Tanh()
        )

    def forward(self, x):
        # Apply initial convolutional layer
        x = self.initial_conv(x)

        # Apply downsampling layers
        x = self.downsampling_1(x)
        x = self.downsampling_2(x)

        # Apply residual layers
        x = self.residual_layers(x)

        # Apply upsampling layers
        x = self.upsampling_1(x)
        x = self.upsampling_2(x)

        # Apply final convolutional layer
        x = self.final_conv(x)

        return x

    
'''PatchGAN Discriminator'''

class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape

        # Calculate output shape of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_channels, out_channels, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # C64 -> C128 -> C256 -> C512
        self.model = nn.Sequential(
            *discriminator_block(channels, out_channels=64, normalize=False),
            *discriminator_block(64, out_channels=128),
            *discriminator_block(128, out_channels=256),
            *discriminator_block(256, out_channels=512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, padding=1)
        )

    def forward(self, img):
        return self.model(img)
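
# Usage sketch (assumed): output_shape sizes the real/fake targets for the
# patch-level GAN loss, e.g. with a hypothetical discriminator D_A:
#   valid = torch.ones((real_A.size(0), *D_A.output_shape))
#   fake = torch.zeros((real_A.size(0), *D_A.output_shape))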

Now, around the third epoch of training, I am getting these "holes" in the generated pictures. Could anyone tell me why they are showing up and how I can prevent them? These are my hyperparameters:

{'name': 'CycleGan_VanGogh_Checkpoint', 'n_epochs': 20, 'batch_size': 4, 'lr': 0.0002, 'decay_start_epoch': 19, 'b1': 0.5, 'b2': 0.999, 'img_size': 256, 'channels': 3, 'num_residual_blocks': 9, 'lambda_cyc': 10.0, 'lambda_id': 5.0}

![outputs-2](https://user-images.githubusercontent.com/73417041/231626093-fc0ce59a-7e10-42b0-b38c-ab672f70b0bf.png)

Upvotes: -1

Views: 209

Answers (1)

AVManerikar

Reputation: 189

This could be a result of using nn.ConvTranspose2d() for upsampling in the generator network. Transposed convolutions are known to produce checkerboard artifacts in output images. A good description of how to avoid this can be found here: https://distill.pub/2016/deconv-checkerboard/
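
If you want to try the fix described there, here is a minimal sketch (my own suggestion, not from the question; it assumes the rest of the generator stays as posted) that replaces each nn.ConvTranspose2d stage with nearest-neighbour upsampling followed by an ordinary convolution, the "resize-convolution" the article recommends:

import torch.nn as nn

# Resize-convolution: upsample first, then convolve. Every output pixel is
# covered by the same number of kernel applications, so there is no uneven
# overlap to produce the checkerboard/"hole" pattern.
def upsample_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
        nn.InstanceNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

# In GeneratorResNet.__init__ the two upsampling stages would become:
#   self.upsampling_1 = upsample_block(256, 128)
#   self.upsampling_2 = upsample_block(128, 64)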

Upvotes: 2
