Lester Li

Reputation: 1

How to get gradients in PyTorch after matrix multiplication?

I want to compute a matrix product in the latent space and have the optimizer update the weight matrix. I have tried several ways to do this, but the value of 'pi_' in the code below never changes. What should I do?

I've tried different functions to compute the product, such as torch.mm(), torch.matmul(), and the @ operator. The weight matrix 'pi_' never changed.
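
As a sanity check, gradients do flow through torch.matmul into an nn.Parameter in isolation; here is a minimal sketch, separate from the code below:

import torch

w = torch.nn.Parameter(torch.ones(10, 1))  # plays the role of pi_
x = torch.randn(4, 10)                     # stand-in for a latent batch
out = torch.matmul(x, w)                   # same result as x.mm(w) or x @ w
out.sum().backward()
print(w.grad)                              # a non-None, non-zero gradient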

import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
#from torchvision import transforms
from torchvision.datasets import MNIST

def get_mnist(data_dir='./data/mnist/',batch_size=128):
    train=MNIST(root=data_dir,train=True,download=True)
    test=MNIST(root=data_dir,train=False,download=True)

    X=torch.cat([train.data.float().view(-1,784)/255.,test.data.float().view(-1,784)/255.],0)
    Y=torch.cat([train.targets,test.targets],0)

    dataset=dict()
    dataset['X']=X
    dataset['Y']=Y

    dataloader=DataLoader(TensorDataset(X,Y),batch_size=batch_size,shuffle=True)

    return dataloader

class tests(torch.nn.Module):
    def __init__(self):
        super(tests, self).__init__()

        self.pi_= torch.nn.Parameter(torch.FloatTensor(10, 1).fill_(1),requires_grad=True)
        self.linear0 = torch.nn.Linear(784,10)
        self.linear1 = torch.nn.Linear(1,784)

    def forward(self, data):
        data = torch.nn.functional.relu(self.linear0(data))
#        data = data.mm(self.pi_)
#        data = torch.mm(data, self.pi_)
#        data = data @ self.pi_
        data = torch.matmul(data, self.pi_)
        data = torch.nn.functional.relu(self.linear1(data))
        return data

if __name__ == '__main__':
    DL=get_mnist()
    t = tests().cuda()
    optimizer = torch.optim.Adam(t.parameters(), lr = 2e-3)

    for i in range(100):
        for inputs, classes in DL:
            inputs = inputs.cuda()

            res = t(inputs)    
            loss = torch.nn.functional.mse_loss(res, inputs)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("Epoch:", i,"pi:",t.pi_)

Upvotes: 0

Views: 732

Answers (1)

Raven Cheuk

Reputation: 3053

TL;DR: Your network has more parameters than it needs; some of them become effectively useless and therefore stop being updated. Change your network architecture to remove the redundant parameters.

Full explanation: The weight matrix pi_ does change. You initialize pi_ as all ones; after the first epoch, the weight matrix pi_ becomes

output >>>
tensor([[0.9879],
        [0.9874],
        [0.9878],
        [0.9880],
        [0.9876],
        [0.9878],
        [0.9878],
        [0.9873],
        [0.9877],
        [0.9871]], device='cuda:0', requires_grad=True)

So it has changed, but only slightly. The precise reason involves some mathematics, but put non-mathematically: this layer contributes very little to the loss, so the optimizer barely updates it. In other words, pi_ is redundant in this network.

If you want to observe the change in pi_, you should modify the neural network such that pi_ is not redundant anymore.
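
One way to see the redundancy directly is to print the gradient that reaches pi_ after backward(); it stays close to zero once the surrounding layers adapt. A minimal sketch, reusing t, inputs, and optimizer from the training loop above:

res = t(inputs)
loss = torch.nn.functional.mse_loss(res, inputs)
optimizer.zero_grad()
loss.backward()
print(t.pi_.grad.abs().mean())  # tiny value: almost no learning signal reaches pi_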

One possible modification is to change the reconstruction task into predicting the class label; the code below regresses the label value with the same MSE loss:

import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
#from torchvision import transforms
from torchvision.datasets import MNIST

def get_mnist(data_dir='./data/mnist/',batch_size=128):
    train=MNIST(root=data_dir,train=True,download=True)
    test=MNIST(root=data_dir,train=False,download=True)

    X=torch.cat([train.data.float().view(-1,784)/255.,test.data.float().view(-1,784)/255.],0)
    Y=torch.cat([train.targets,test.targets],0)

    dataset=dict()
    dataset['X']=X
    dataset['Y']=Y

    dataloader=DataLoader(TensorDataset(X,Y),batch_size=batch_size,shuffle=True)

    return dataloader

class tests(torch.nn.Module):
    def __init__(self):
        super(tests, self).__init__()

#         self.pi_= torch.nn.Parameter(torch.randn((10, 1),requires_grad=True))
        self.pi_= torch.nn.Parameter(torch.FloatTensor(10, 1).fill_(1),requires_grad=True)
        self.linear0 = torch.nn.Linear(784,10)
#         self.linear1 = torch.nn.Linear(1,784)

    def forward(self, data):
        data = torch.nn.functional.relu(self.linear0(data))
#        data = data.mm(self.pi_)
#        data = torch.mm(data, self.pi_)
#        data = data @ self.pi_
        data = torch.matmul(data, self.pi_)
#         data = torch.nn.functional.relu(self.linear1(data))
        return data

if __name__ == '__main__':
    DL=get_mnist()
    t = tests().cuda()
    optimizer = torch.optim.Adam(t.parameters(), lr = 2e-3)

    for i in range(100):
        for inputs, classes in DL:
            inputs = inputs.cuda()
            classes = classes.cuda().float()
            output = t(inputs)    
            loss = torch.nn.functional.mse_loss(output.view(-1), classes)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
#         print("Epoch:", i, "pi_grad", t.pi_.grad)
        print("Epoch:", i,"pi:",t.pi_)

Now, pi_ changes every single epoch.

output >>>
Epoch: 0 pi: Parameter containing:
tensor([[1.3429],
        [1.0644],
        [0.9817],
        [0.9767],
        [0.9715],
        [1.1110],
        [1.1139],
        [0.9759],
        [1.2424],
        [1.2632]], device='cuda:0', requires_grad=True)
Epoch: 1 pi: Parameter containing:
tensor([[1.4413],
        [1.1977],
        [0.9588],
        [1.0325],
        [0.9241],
        [1.1988],
        [1.1690],
        [0.9248],
        [1.2892],
        [1.3427]], device='cuda:0', requires_grad=True)
Epoch: 2 pi: Parameter containing:
tensor([[1.4653],
        [1.2351],
        [0.9539],
        [1.1588],
        [0.8670],
        [1.2739],
        [1.2058],
        [0.8648],
        [1.2848],
        [1.3891]], device='cuda:0', requires_grad=True)
Epoch: 3 pi: Parameter containing:
tensor([[1.4375],
        [1.2256],
        [0.9580],
        [1.2293],
        [0.8174],
        [1.3471],
        [1.2035],
        [0.8102],
        [1.2505],
        [1.4201]], device='cuda:0', requires_grad=True)
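
If you would rather verify the update programmatically than read printed tensors, you can snapshot pi_ before an optimizer step and compare afterwards. A small sketch, reusing t, inputs, classes, and optimizer from the code above:

before = t.pi_.detach().clone()          # copy of pi_ before the step
output = t(inputs)
loss = torch.nn.functional.mse_loss(output.view(-1), classes)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(torch.equal(before, t.pi_.detach()))  # False once pi_ is actually updated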

Upvotes: 1
