Mstislaw

Reputation: 153

How to reshape multichannel image with a PyTorch encoder?

I have a tensor with dimensions [18, 512, 512], representing grayscale heatmaps for up to 18 specific objects on a 512x512 image. To generate a suitable representation of this image for my conditional GAN, I need to reshape this tensor into a [512, 4, 4] shape using an encoder. However, I can't understand how this transformation can be achieved, since the given dimensions appear too mismatched for direct linear or convolutional transformations.

import torch
from torch import nn

class HeatmapEncoder(torch.nn.Module):
  def __init__(self):
    super().__init__()
    # source = 18x512x512
    # target = 512x4x4
    self.encoder = torch.nn.Sequential(
        nn.Linear(),   # in/out feature sizes still to be worked out
        nn.ReLU(),
        nn.Linear()
    )

  def forward(self, x):
    pass

It is possible to use `nn.Flatten()` with start_dim=0 here, but the result is a single flattened vector that is far too large to feed into a linear layer. The decoder part is not especially important right now, since I only need a low-dimensional representation of the heatmaps to condition my GAN, not to recreate those images.
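For a sense of scale, here is a quick check (the shapes are from the question above; the weight-count arithmetic is only an illustration of why the direct route fails):

import torch
from torch import nn

x = torch.rand(18, 512, 512)
flat = nn.Flatten(start_dim=0)(x)
print(flat.shape)   # torch.Size([4718592])

# Mapping this directly to the 512*4*4 = 8192 target with a single
# nn.Linear(4718592, 8192) would take roughly 38.7 billion weights,
# which is why a one-shot linear transformation is impractical.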

Upvotes: 0

Views: 633

Answers (1)

yutasrobot

Reputation: 2496

You can try a couple of different approaches for your problem, like viewing the input as one big vector and slowly reducing it to the target size, or permuting dimensions and applying different operations. Since there is no full code for testing on the actual problem, I can't really say which will work better, but my first instinct says a convolution-based dimension reduction is quite suitable here. First code, then talk:

import torch
from torch import nn

class ReduceConv(torch.nn.Module):
  def __init__(self, nin, nout, activ=nn.ReLU):
    super(ReduceConv, self).__init__()

    # source = Batch x nin x H x W
    # target = Batch x nout x (H/2) x (W/2)

    self.conv = nn.Sequential(
        # 3x3 conv: changes the channel count, keeps the resolution
        nn.Conv2d(
            nin, nout,
            kernel_size=3,
            stride=1,
            padding=1),
        # stride=2 conv: halves the spatial dimensions
        nn.Conv2d(
            nout, nout,
            kernel_size=3,
            stride=2,
            padding=1),
        nn.BatchNorm2d(nout),
        activ()
    )

  def forward(self, x):
    return self.conv(x)

class HeatmapEncoder(torch.nn.Module):
  def __init__(self):
    super(HeatmapEncoder, self).__init__()

    # source = 18x512x512
    # target = 512x4x4
    self.encoder = torch.nn.Sequential(
        ReduceConv(18, 32),       # out->  32 256 256 
        ReduceConv(32, 64),       # out->  64 128 128
        ReduceConv(64, 64),       # out->  64 64 64
        ReduceConv(64, 64),       # out->  64 32 32
        ReduceConv(64, 128),      # out-> 128 16 16
        ReduceConv(128, 256),     # out-> 256 8 8
        ReduceConv(256, 512)      # out-> 512 4 4
    )

  def forward(self, x):
    return self.encoder(x)

# 10 is batch size
inp = torch.rand(10, 18, 512, 512)
enc = HeatmapEncoder()
out = enc(inp)
print(inp.shape)      # torch.Size([10, 18, 512, 512])
print(out.shape)      # torch.Size([10, 512, 4, 4]) 

It is essentially just a stack of convolution layers. Note that in each ReduceConv layer, the spatial dimensions are halved by the stride=2 convolution. You don't technically need the first convolution in the ReduceConv layer, but it's the deep learning era, the more the merrier :) I've also added BatchNorm after each reduction, along with an activation function. The 18 heatmaps are treated as input channels by the first convolution. This way the channels build up to 512 while the width and height are halved after each operation. This encoder is probably not the best or most efficient model, but it should be good enough for your problem.
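If you want to watch the halving happen stage by stage, a quick sanity check (assuming the ReduceConv and HeatmapEncoder classes defined above) is to walk through the Sequential and print each intermediate shape:

# Assumes ReduceConv and HeatmapEncoder from the code above.
x = torch.rand(1, 18, 512, 512)
for i, layer in enumerate(HeatmapEncoder().encoder):
  x = layer(x)
  print(i, x.shape)
# 0 torch.Size([1, 32, 256, 256])
# 1 torch.Size([1, 64, 128, 128])
# ...
# 6 torch.Size([1, 512, 4, 4])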

Upvotes: 1
