AMendis

Reputation: 1584

Import transparent images to GAN

I have an image set that has transparency.

I'm trying to train a GAN (Generative Adversarial Network).

How can I preserve the transparency? I can see in the output images that all the transparent areas are BLACK.

How can I avoid that?

I think this is called the "alpha channel".

Anyway, how can I keep my transparency?

Below is my code.

# Importing the libraries
from __future__ import print_function
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from generator import G
from discriminator import D
import os

batchSize = 64  # We set the size of the batch.
imageSize = 64  # We set the size of the generated images (64x64).
input_vector = 100
nb_epochs = 500
# Creating the transformations
transform = transforms.Compose([transforms.Resize((imageSize, imageSize)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5))])  # We create a list of transformations (resizing, tensor conversion, normalization) to apply to the input images.

# Loading the dataset
dataset = dset.ImageFolder(root='./data', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True,
                                         num_workers=2)  # We use dataLoader to get the images of the training set batch by batch.


# Defining the weights_init function that takes as input a neural network m and that will initialize all its weights.
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def is_cuda_available():
    return torch.cuda.is_available()


def is_gpu_available():
    if is_cuda_available():
        if int(torch.cuda.device_count()) > 0:
            return True
        return False
    return False


# Create results directory
def create_dir(name):
    if not os.path.exists(name):
        os.makedirs(name)


# Creating the generator
netG = G(input_vector)
netG.apply(weights_init)

# Creating the discriminator
netD = D()
netD.apply(weights_init)

if is_gpu_available():
    netG.cuda()
    netD.cuda()

# Training the DCGANs

criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

generator_model = 'generator_model'
discriminator_model = 'discriminator_model'


def save_model(epoch, model, optimizer, error, filepath, noise=None):
    if os.path.exists(filepath):
        os.remove(filepath)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': error,
        'noise': noise
    }, filepath)


def load_checkpoint(filepath):
    if os.path.exists(filepath):
        return torch.load(filepath)
    return None

def main():
    print("Device name : " + torch.cuda.get_device_name(0))
    for epoch in range(nb_epochs):

        for i, data in enumerate(dataloader, 0):
            checkpointG = load_checkpoint(generator_model)
            checkpointD = load_checkpoint(discriminator_model)
            if checkpointG:
                netG.load_state_dict(checkpointG['model_state_dict'])
                optimizerG.load_state_dict(checkpointG['optimizer_state_dict'])
            if checkpointD:
                netD.load_state_dict(checkpointD['model_state_dict'])
                optimizerD.load_state_dict(checkpointD['optimizer_state_dict'])

            # 1st Step: Updating the weights of the neural network of the discriminator

            netD.zero_grad()

            # Training the discriminator with a real image of the dataset
            real, _ = data
            if is_gpu_available():
                input = Variable(real.cuda()).cuda()
                target = Variable(torch.ones(input.size()[0]).cuda()).cuda()
            else:
                input = Variable(real)
                target = Variable(torch.ones(input.size()[0]))
            output = netD(input)
            errD_real = criterion(output, target)

            # Training the discriminator with a fake image generated by the generator
            if is_gpu_available():
                noise = Variable(torch.randn(input.size()[0], input_vector, 1, 1)).cuda()
                target = Variable(torch.zeros(input.size()[0])).cuda()
            else:
                noise = Variable(torch.randn(input.size()[0], input_vector, 1, 1))
                target = Variable(torch.zeros(input.size()[0]))
            fake = netG(noise)
            output = netD(fake.detach())
            errD_fake = criterion(output, target)

            # Backpropagating the total error
            errD = errD_real + errD_fake
            errD.backward()
            optimizerD.step()

            # 2nd Step: Updating the weights of the neural network of the generator
            netG.zero_grad()
            if is_gpu_available():
                target = Variable(torch.ones(input.size()[0])).cuda()
            else:
                target = Variable(torch.ones(input.size()[0]))
            output = netD(fake)
            errG = criterion(output, target)
            errG.backward()
            optimizerG.step()

            # 3rd Step: Printing the losses and saving the real images and the generated images of the minibatch every 100 steps

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch, nb_epochs, i, len(dataloader), errD.data, errG.data))
            save_model(epoch, netG, optimizerG, errG, generator_model, noise)
            save_model(epoch, netD, optimizerD, errD, discriminator_model, noise)

            if i % 100 == 0:
                create_dir('results')
                vutils.save_image(real, '%s/real_samples.png' % "./results", normalize=True)
                fake = netG(noise)
                vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize=True)


if __name__ == "__main__":
    main()

generator.py

import torch.nn as nn


class G(nn.Module):
    feature_maps = 512
    kernel_size = 4
    stride = 2
    padding = 1
    bias = False

    def __init__(self, input_vector):
        super(G, self).__init__()
        self.main = nn.Sequential(
            # project the noise vector up to 512 feature maps
            nn.ConvTranspose2d(input_vector, self.feature_maps, self.kernel_size, 1, 0, bias=self.bias),
            nn.BatchNorm2d(self.feature_maps), nn.ReLU(True),
            nn.ConvTranspose2d(self.feature_maps, self.feature_maps // 2, self.kernel_size, self.stride,
                               self.padding, bias=self.bias),
            nn.BatchNorm2d(self.feature_maps // 2), nn.ReLU(True),
            nn.ConvTranspose2d(self.feature_maps // 2, self.feature_maps // 4, self.kernel_size, self.stride,
                               self.padding, bias=self.bias),
            nn.BatchNorm2d(self.feature_maps // 4), nn.ReLU(True),
            nn.ConvTranspose2d(self.feature_maps // 4, self.feature_maps // 8, self.kernel_size, self.stride,
                               self.padding, bias=self.bias),
            nn.BatchNorm2d(self.feature_maps // 8), nn.ReLU(True),
            # final layer emits 4 channels (RGBA)
            nn.ConvTranspose2d(self.feature_maps // 8, 4, self.kernel_size, self.stride, self.padding,
                               bias=self.bias),
            nn.Tanh()
        )

    def forward(self, input):
        output = self.main(input)
        return output

discriminator.py

import torch.nn as nn
class D(nn.Module):
    feature_maps = 64
    kernel_size = 4
    stride = 2
    padding = 1
    bias = False
    inplace = True

    def __init__(self):
        super(D, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(4, self.feature_maps, self.kernel_size, self.stride, self.padding, bias=self.bias),
            nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(self.feature_maps, self.feature_maps * 2, self.kernel_size, self.stride, self.padding,
                      bias=self.bias),
            nn.BatchNorm2d(self.feature_maps * 2), nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(self.feature_maps * 2, self.feature_maps * (2 * 2), self.kernel_size, self.stride, self.padding,
                      bias=self.bias),
            nn.BatchNorm2d(self.feature_maps * (2 * 2)), nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(self.feature_maps * (2 * 2), self.feature_maps * (2 * 2 * 2), self.kernel_size, self.stride,
                      self.padding, bias=self.bias),
            nn.BatchNorm2d(self.feature_maps * (2 * 2 * 2)), nn.LeakyReLU(0.2, inplace=self.inplace),
            nn.Conv2d(self.feature_maps * (2 * 2 * 2), 1, self.kernel_size, 1, 0, bias=self.bias),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1)

Upvotes: 3

Views: 573

Answers (1)

Shai

Reputation: 114976

Using dset.ImageFolder without explicitly passing the function that reads the image (the loader) results in your dataset using the default pil_loader:

def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

As you can see, the default loader discards the alpha channel and forces the image to have only three color channels: RGB.
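
This is also why your transparent areas come out black: PIL's convert('RGB') does not composite the image onto a background, it simply drops the alpha channel, and fully transparent pixels typically carry (0, 0, 0) in their RGB values. A minimal sketch illustrating the effect with a tiny in-memory image:

from PIL import Image

# a 2x2 fully transparent image whose underlying RGB values are black
img = Image.new('RGBA', (2, 2), (0, 0, 0, 0))
print(img.mode, img.getpixel((0, 0)))   # RGBA (0, 0, 0, 0)

# convert('RGB') discards the alpha channel, so the formerly
# transparent pixels become plain black
rgb = img.convert('RGB')
print(rgb.mode, rgb.getpixel((0, 0)))   # RGB (0, 0, 0)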

You can define your own loader:

from PIL import Image


def pil_loader_rgba(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGBA')  # force the alpha channel to be kept

You can use this loader in your dataset:

dataset = dset.ImageFolder(root='./data', transform=transform, loader=pil_loader_rgba)

Now your images will have the alpha channel.
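
One caveat that follows from this: the transforms.Normalize in your pipeline passes three means and three stds, which will raise an error once the tensors have four channels. A sketch of the adjusted transform; the 0.5/0.5 statistics for the alpha channel are an assumed placeholder mirroring the RGB ones, not tuned values:

import torchvision.transforms as transforms

imageSize = 64
transform = transforms.Compose([
    transforms.Resize((imageSize, imageSize)),
    transforms.ToTensor(),  # RGBA PIL image -> 4-channel float tensor in [0, 1]
    # four means/stds now, one per RGBA channel
    transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5)),
])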

Note that the transparency ("alpha channel") is an additional, fourth channel and is not part of the RGB channels. You need to make sure your model knows how to handle 4-channel inputs; otherwise, you'll run into channel-mismatch errors (for instance, a first conv layer that expects 3 input channels being fed 4).
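
Your posted G and D already do this (the generator's last ConvTranspose2d emits 4 channels and the discriminator's first Conv2d accepts 4), but for anyone adapting a standard 3-channel DCGAN, the only structural change sits at the two ends of the pipeline. A minimal sketch, assuming the usual DCGAN kernel/stride/padding of 4/2/1:

import torch.nn as nn

n_channels = 4   # RGBA instead of the usual 3-channel RGB
feature_maps = 64

# discriminator entry point: the first conv must accept 4 input channels...
first_conv = nn.Conv2d(n_channels, feature_maps, 4, 2, 1, bias=False)

# ...and the generator's final layer must emit 4 channels so the
# fake images also carry an alpha channel
last_deconv = nn.ConvTranspose2d(feature_maps, n_channels, 4, 2, 1, bias=False)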

Upvotes: 2
