Reputation: 3561
In this question on Math StackExchange, people are discussing the derivative of the function f(x) = A x x' A / (x' A A x), where x is a vector and A is a symmetric, positive semi-definite square matrix. The derivative of this function at a point x is a tensor, and when "applied" to another vector h it is a matrix. The answers under that post differ in their expressions for this matrix, so I would like to check them numerically using PyTorch or Autograd.
Here is my attempt with PyTorch:
import torch

def P(x, A):
    x = x.unsqueeze(1)  # Convert to column vector
    vector = torch.matmul(A, x)  # A x
    denom = (vector.transpose(0, 1) @ vector).squeeze()  # x' A A x
    P_matrix = (vector @ vector.transpose(0, 1)) / denom  # A x x' A / (x' A A x)
    return P_matrix.squeeze()

A = torch.tensor([[1.0, 0.5], [0.5, 1.3]], dtype=torch.float32)
x = torch.tensor([1.0, 2.0], dtype=torch.float32, requires_grad=True)
h = torch.tensor([2.0, -1.0], dtype=torch.float32)

Pxh = torch.matmul(P(x, A), h)

# compute gradient
Pxh.backward()
But this doesn't work. What am I doing wrong?
I am also happy with a JAX solution. I tried jax.grad, but it does not work either.
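For reference, jax.grad only supports scalar-valued functions, which is likely why it fails here. A minimal, unverified JAX sketch of the check I have in mind, using jax.jacfwd to get the full Jacobian of x -> P(x)h instead:

import jax.numpy as jnp
from jax import jacfwd

def P(x, A):
    v = A @ x                         # A x
    return jnp.outer(v, v) / (v @ v)  # A x x' A / (x' A A x)

A = jnp.array([[1.0, 0.5], [0.5, 1.3]])
x = jnp.array([1.0, 2.0])
h = jnp.array([2.0, -1.0])

# Jacobian of x -> P(x) h; row i is the gradient of the i-th output component w.r.t. x
J = jacfwd(lambda x: P(x, A) @ h)(x)
print(J)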
Upvotes: 0
Views: 85
Reputation: 5373
You need to pass a gradient tensor, e.g. a tensor of ones, if you want to backprop non-scalar values: backward() computes a vector-Jacobian product, and for a non-scalar output the vector must be supplied explicitly.
import torch

def P(x, A):
    x = x.unsqueeze(1)  # Convert to column vector
    vector = torch.matmul(A, x)
    denom = (vector.transpose(0, 1) @ vector).squeeze()
    P_matrix = (vector @ vector.transpose(0, 1)) / denom
    return P_matrix.squeeze()

A = torch.tensor([[1.0, 0.5], [0.5, 1.3]], dtype=torch.float32)
x = torch.tensor([1.0, 2.0], dtype=torch.float32, requires_grad=True)
h = torch.tensor([2.0, -1.0], dtype=torch.float32)

Pxh = torch.matmul(P(x, A), h)
Pxh.backward(torch.ones_like(Pxh))
x.grad
> tensor([ 0.4853, -0.2427])
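Note that with torch.ones_like(Pxh), x.grad is the gradient of the scalar sum of the entries of P(x)h with respect to x, not the full derivative. If the goal is to compare against the matrix expressions in the linked post, one possible sketch (reusing P, A, x, h from above) is to compute the full Jacobian of x -> P(x)h with torch.autograd.functional.jacobian:

from torch.autograd.functional import jacobian

# Full 2x2 Jacobian of x -> P(x) h; row i holds the gradient of (P(x) h)_i w.r.t. x
J = jacobian(lambda x: torch.matmul(P(x, A), h), x)
print(J)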
Upvotes: 1