Abo Omar

Reputation: 135

Plot the derivative of a function with PyTorch?

I have this code:

import torch
import matplotlib.pyplot as plt  
x=torch.linspace(-10, 10, 10, requires_grad=True)
y = torch.sum(x**2)
y.backward()
plt.plot(x.detach().numpy(), y.detach().numpy(), label='function')
plt.legend()

But I get this error:

ValueError: x and y must have same first dimension, but have shapes (10,) and (1,)

Upvotes: 2

Views: 1566

Answers (2)

MBT

Reputation: 24169

The main problem is that your dimensions do not match: torch.sum reduces y to a single value, while x has 10 elements. Why do you want to use torch.sum?

This should work for you:

# %matplotlib inline  (add this line only in a Jupyter notebook)
import torch
import matplotlib.pyplot as plt  
x = torch.linspace(-10, 10, 10, requires_grad=True)

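y = x**2      # removed the sum so y keeps the same shape as x
# y isn't a scalar anymore, so backward() needs a gradient argument;
# a tensor of ones gives the plain derivative 2*x (passing x would give 2*x**2)
y.backward(torch.ones_like(x))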
# your function
plt.plot(x.detach().numpy(), y.detach().numpy(), label='x**2')
# gradients
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='grad')
plt.legend()

You get a smoother picture with more steps. I also narrowed the interval, using torch.linspace(-2.5, 2.5, 50, requires_grad=True).
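As a side note on the argument passed to backward(): for a non-scalar y, the tensor you pass weights each output's gradient before they are accumulated into x.grad. Here is a minimal sketch of the difference between passing ones and passing x, assuming the elementwise function x**2 from above (the names grad_ones and grad_x are only for illustration):

```python
import torch

x = torch.linspace(-2.5, 2.5, 50, requires_grad=True)

# Weighting every output by 1 yields the plain derivative dy/dx = 2*x.
y = x**2
y.backward(torch.ones_like(x))
grad_ones = x.grad.clone()

# Gradients accumulate across backward calls, so reset before the next one.
x.grad = None

# Weighting by x instead multiplies each entry by x_i, giving 2*x**2.
y = x**2
y.backward(x.detach())
grad_x = x.grad.clone()

print(torch.allclose(grad_ones, 2 * x.detach()))
print(torch.allclose(grad_x, 2 * x.detach() ** 2))
```

So if the goal is to plot the derivative itself, a tensor of ones is the weighting to use.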

Edit regarding comment:

This version plots the gradients with torch.sum included:

# %matplotlib inline  (add this line only in a Jupyter notebook)
import torch
import matplotlib.pyplot as plt  
x = torch.linspace(-10, 10, 10, requires_grad=True)

y = torch.sum(x**2) 
y.backward() 
print(x.grad)
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='grad')
plt.legend()

Output:

tensor([-20.0000, -15.5556, -11.1111,  -6.6667,  -2.2222,   2.2222,
      6.6667,  11.1111,  15.5556,  20.0000])

Upvotes: 4

Harshit Kumar

Reputation: 12857

I'm assuming you want to plot the graph of the derivative of x**2.

In that case, plot x against x.grad, not x against y:

plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='derivative')
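If you would rather not rely on x.grad being filled in by backward(), torch.autograd.grad returns the gradient directly. The sketch below assumes the function is applied elementwise (as x**2 is); derivative and plot_function_and_derivative are hypothetical helper names, not part of PyTorch:

```python
import torch

def derivative(f, x):
    """Hypothetical helper: return (f(x), df/dx) for an elementwise f."""
    # Work on a fresh leaf tensor so the caller's x is left untouched.
    x = x.detach().clone().requires_grad_(True)
    y = f(x)
    # grad_outputs of ones gives dy_i/dx_i for each element.
    (dy_dx,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))
    return y.detach(), dy_dx

def plot_function_and_derivative(f, x):
    """Hypothetical helper: plot f and its derivative on the same axes."""
    import matplotlib.pyplot as plt  # imported here so derivative() works without it
    y, dy_dx = derivative(f, x)
    plt.plot(x.detach().numpy(), y.numpy(), label='function')
    plt.plot(x.detach().numpy(), dy_dx.numpy(), label='derivative')
    plt.legend()
    plt.show()

xs = torch.linspace(-10, 10, 10)
ys, grads = derivative(lambda t: t**2, xs)
print(grads)
```

Calling plot_function_and_derivative(lambda t: t**2, xs) would then draw both curves with matching shapes, which avoids the ValueError from the question.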

Upvotes: 0
