flawr

Reputation: 11628

Can pytorch's autograd handle torch.cat?

I'm trying to implement a simple neural network that is supposed to learn a grayscale image. The input consists of the 2d indices of a pixel; the output should be the value of that pixel.

The net is constructed as follows: each neuron is connected to the input (i.e. the indices of the pixel) as well as to the output of each previous neuron. The output of the net is just the output of the last neuron in this sequence.

This kind of network has been very successful in learning images, as demonstrated e.g. here.

The problem: In my implementation the loss stays between 0.2 and 0.4, depending on the number of neurons, the learning rate and the number of iterations, which is pretty bad. Also, if you compare the output to what we trained it on, it just looks like noise. This is the first time I've used torch.cat within a network, so I'm not sure whether it's the culprit. Can anyone see what I'm doing wrong?

from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Linear

class My_Net(nn.Module):
    lin: List[Linear]

    def __init__(self):
        super(My_Net, self).__init__()
        self.num_neurons = 10
        self.lin = nn.ModuleList([nn.Linear(k+2, 1) for k in range(self.num_neurons)])

    def forward(self, x):
        v = x
        recent = torch.Tensor(0)  # placeholder; overwritten in the first loop iteration
        for k in range(self.num_neurons):
            recent = F.relu(self.lin[k](v))
            v = torch.cat([v, recent], dim=1)
        return recent

    def num_flat_features(self, x):
        size = x.size()[1:]
        num = 1
        for s in size:
            num *= s
        return num

my_net = My_Net()
print(my_net)

#define a small 3x3 image that the net is supposed to learn
my_image = [[1.0, 1.0, 1.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] #represents a T-shape
my_image_flat = []    #output of the net: the value of a pixel
my_image_indices = [] #input to the net: the 2d indices of a pixel
for i in range(len(my_image)):
    for j in range(len(my_image[i])):
        my_image_flat.append(my_image[i][j])
        my_image_indices.append([i, j])

#optimization loop
for i in range(100):
    inp = torch.Tensor(my_image_indices)

    out = my_net(inp)

    target = torch.Tensor(my_image_flat)
    criterion = nn.MSELoss()
    loss = criterion(out.view(-1), target)
    print(loss)

    my_net.zero_grad()
    loss.backward()
    optimizer = optim.SGD(my_net.parameters(), lr=0.001)
    optimizer.step()

print("output of current image")
print([[my_net(torch.Tensor([[i, j]])).item() for j in range(3)] for i in range(3)])
print("output of original image")
print(my_image)

Upvotes: 1

Views: 2616

Answers (1)

MBT

Reputation: 24169

Yes, torch.cat is backprop-able, so you can use it here without problems.
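
As a quick sanity check (a standalone sketch, not taken from your code), you can verify that gradients flow through torch.cat:

import torch

a = torch.ones(2, 2, requires_grad=True)
b = torch.ones(2, 2, requires_grad=True)
c = torch.cat([a, b], dim=1)  # the concatenation is recorded in the autograd graph
c.sum().backward()
print(a.grad)  # all ones -- gradients reached both inputs
print(b.grad)  # all ones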

The problem here is that you create a new optimizer at every iteration. Instead, you should define it once, after defining your model.

With this change the code works fine and the loss decreases continuously. I also added a printout every 5000 iterations to show the progress.

from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Linear

class My_Net(nn.Module):
    lin: List[Linear]

    def __init__(self):
        super(My_Net, self).__init__()
        self.num_neurons = 10
        self.lin = nn.ModuleList([nn.Linear(k+2, 1) for k in range(self.num_neurons)])

    def forward(self, x):
        v = x
        recent = torch.Tensor(0)  # placeholder; overwritten in the first loop iteration
        for k in range(self.num_neurons):
            recent = F.relu(self.lin[k](v))
            v = torch.cat([v, recent], dim=1)
        return recent

    def num_flat_features(self, x):
        size = x.size()[1:]
        num = 1
        for s in size:
            num *= s
        return num

my_net = My_Net()
print(my_net)

optimizer = optim.SGD(my_net.parameters(), lr=0.001)  # created once, reused in every iteration

#define a small 3x3 image that the net is supposed to learn
my_image = [[1.0, 1.0, 1.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] #represents a T-shape
my_image_flat = []    #output of the net: the value of a pixel
my_image_indices = [] #input to the net: the 2d indices of a pixel
for i in range(len(my_image)):
    for j in range(len(my_image[i])):
        my_image_flat.append(my_image[i][j])
        my_image_indices.append([i, j])

#optimization loop
for i in range(50000):
    inp = torch.Tensor(my_image_indices)

    out = my_net(inp)

    target = torch.Tensor(my_image_flat)
    criterion = nn.MSELoss()
    loss = criterion(out.view(-1), target)
    if i % 5000 == 0:
        print('Iteration:', i, 'Loss:', loss)

    my_net.zero_grad()
    loss.backward()
    optimizer.step()
print('Iteration:', i, 'Loss:', loss)

print("output of current image")
print([[my_net(torch.Tensor([[i, j]])).item() for j in range(3)] for i in range(3)])
print("output of original image")
print(my_image)

Loss output:

Iteration: 0 Loss: tensor(0.4070)
Iteration: 5000 Loss: tensor(0.1315)
Iteration: 10000 Loss: tensor(1.00000e-02 * 8.8275)
Iteration: 15000 Loss: tensor(1.00000e-02 * 5.6190)
Iteration: 20000 Loss: tensor(1.00000e-02 * 3.2540)
Iteration: 25000 Loss: tensor(1.00000e-02 * 1.3628)
Iteration: 30000 Loss: tensor(1.00000e-03 * 4.4690)
Iteration: 35000 Loss: tensor(1.00000e-03 * 1.3582)
Iteration: 40000 Loss: tensor(1.00000e-04 * 3.4776)
Iteration: 45000 Loss: tensor(1.00000e-05 * 7.9518)
Iteration: 49999 Loss: tensor(1.00000e-05 * 1.7160)

So the loss goes down to 0.000017 in this case. I have to admit that your error surface is quite ragged: depending on the initial weights, it may also converge to a local minimum of 0.17, 0.10, etc. The local minimum it converges to can be very different, so you might try initializing your weights within a smaller range.
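
For example, here is a minimal sketch of such an initialization (the range of 0.1 below is an arbitrary choice on my part, not a tuned value):

import torch.nn as nn

def init_weights(m):
    # draw the Linear weights from a smaller uniform range than the default
    if isinstance(m, nn.Linear):
        nn.init.uniform_(m.weight, -0.1, 0.1)
        nn.init.constant_(m.bias, 0.0)

my_net.apply(init_weights)  # apply() visits every submodule recursively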

Btw., here is the output without moving the optimizer definition out of the loop:

Iteration: 0 Loss: tensor(0.5574)
Iteration: 5000 Loss: tensor(0.5556)
Iteration: 10000 Loss: tensor(0.5556)
Iteration: 15000 Loss: tensor(0.5556)
Iteration: 20000 Loss: tensor(0.5556)
Iteration: 25000 Loss: tensor(0.5556)
Iteration: 30000 Loss: tensor(0.5556)
Iteration: 35000 Loss: tensor(0.5556)
Iteration: 40000 Loss: tensor(0.5556)
Iteration: 45000 Loss: tensor(0.5556)

Upvotes: 6
