Bryce Ramgovind

Reputation: 3257

Compute gradient between a scalar and vector in PyTorch

I am trying to port code written in Theano to PyTorch. In the code, the author computes the gradient using

import theano.tensor as T    
gparams = T.grad(cost, params)

and the shape of gparams is (256, 240)

I have tried using backward() but it doesn't seem to return anything. Is there an equivalent to grad within PyTorch?

Assume this is my input,

import torch
from torch.autograd import Variable 
cost = torch.tensor(1.6019)
params = Variable(torch.rand(1, 73, 240))

Upvotes: 0

Views: 903

Answers (1)

jodag

Reputation: 22184

cost needs to be the result of an operation involving params. You can't compute a gradient knowing only the values of two tensors; you also need to know the relationship between them. This is why PyTorch builds a computation graph when you perform tensor operations. For example, say the relationship is

cost = torch.sum(params)

then we would expect the gradient of cost with respect to params to be a tensor of ones with the same shape as params, regardless of the values in params.
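To make the point about the computation graph concrete, here is a minimal sketch. The name cost_literal is only illustrative, and the shape (1, 73, 240) is borrowed from the question. It compares a tensor created from a plain value with one produced by an operation on params:

```python
import torch

# A tensor built from a literal value has no history attached to it,
# so autograd has nothing to differentiate through.
cost_literal = torch.tensor(1.6019)
print(cost_literal.requires_grad)  # False
print(cost_literal.grad_fn)        # None

# A tensor produced by an operation on params records that operation.
params = torch.rand(1, 73, 240, requires_grad=True)
cost = torch.sum(params)
print(cost.requires_grad)          # True
print(cost.grad_fn is not None)    # True
```

Only the second cost is connected to params, so only it can be differentiated with respect to params.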

The gradient for the torch.sum example can be computed as follows. Note that you need to set the requires_grad flag to tell PyTorch to track operations on params, so that backward() stores a gradient in params.grad.

import torch

# Initialize the independent variable. Make sure to set requires_grad=True.
# The values must be floating point: integer tensors cannot require gradients.
params = torch.tensor((1., 73., 240.), requires_grad=True)

# Compute cost, this implicitly builds a computation graph which records
# how cost was computed with respect to params.
cost = torch.sum(params)

# Zero the gradient of params in case it already has something in it.
# This step is optional in this example but good to do in practice to
# ensure you're not adding gradients to existing gradients.
if params.grad is not None:
    params.grad.zero_()

# Perform backpropagation. This is where the gradient is actually
# computed. By default it also frees the graph's intermediate buffers,
# so pass retain_graph=True if you need to call backward again.
cost.backward()

# The gradient of cost w.r.t. params is now stored in params.grad.
print(params.grad)

Result:

tensor([1., 1., 1.])
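If you want a call that returns the gradient directly, closer in spirit to Theano's T.grad(cost, params), torch.autograd.grad may be a better fit than backward(). The sketch below is only an illustration: the cost function sum(params ** 2) is an assumption chosen so the result is easy to check, and the shape (1, 73, 240) is taken from the question.

```python
import torch

# Independent variable with the shape used in the question.
params = torch.rand(1, 73, 240, requires_grad=True)

# Hypothetical cost; replace with whatever your model actually computes.
cost = torch.sum(params ** 2)

# torch.autograd.grad returns a tuple with one gradient per input,
# instead of accumulating the result into params.grad.
(gparams,) = torch.autograd.grad(cost, params)

print(gparams.shape)                                  # torch.Size([1, 73, 240])
print(torch.allclose(gparams, 2 * params.detach()))   # True
print(params.grad)                                    # None
```

The returned gradient always has the same shape as the tensor it is taken with respect to, so a gparams of shape (256, 240) in the Theano code implies that params there had shape (256, 240) as well.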

Upvotes: 1
