Euler_Salter

Reputation: 3561

Checking derivative tensor in Pytorch

In this question on Math StackExchange, people discuss the derivative of the function P(x) = Axx'A / (x'AAx), where x is a vector and A is a symmetric, positive semi-definite square matrix.

The derivative of this function at a point x is a (third-order) tensor, and when "applied" to another vector h it yields a matrix. The answers under that post give differing expressions for this matrix, so I would like to check them numerically using PyTorch or Autograd.
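Concretely, the matrix I want to compare those expressions against is the directional derivative of P at x in the direction h (standard definition):

DP(x)[h] = lim_{t -> 0} ( P(x + t*h) - P(x) ) / t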

Here is my attempt with PyTorch:

import torch 

def P(x, A):
    x = x.unsqueeze(1)  # Convert to column vector, shape (2, 1)
    vector = torch.matmul(A, x)  # Ax
    denom = (vector.transpose(0, 1) @ vector).squeeze()  # scalar x'AAx
    P_matrix = (vector @ vector.transpose(0, 1)) / denom  # Axx'A / (x'AAx)
    return P_matrix.squeeze()

A = torch.tensor([[1.0, 0.5], [0.5, 1.3]], dtype=torch.float32)
x = torch.tensor([1.0, 2.0], dtype=torch.float32, requires_grad=True)
h = torch.tensor([2.0, -1.0], dtype=torch.float32)

Pxh = torch.matmul(P(x, A), h)

# compute gradient 
Pxh.backward()

But this fails with RuntimeError: grad can be implicitly created only for scalar outputs. What am I doing wrong?

JAX

I would also be happy with a JAX solution. I tried jax.grad, but it does not work either.
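For concreteness, my attempt looked roughly like this (a sketch of the same P written with jax.numpy):

import jax
import jax.numpy as jnp

def P_jax(x, A):
    v = A @ x  # Ax
    return jnp.outer(v, v) / (v @ v)  # Axx'A / (x'AAx)

A = jnp.array([[1.0, 0.5], [0.5, 1.3]])
x = jnp.array([1.0, 2.0])
h = jnp.array([2.0, -1.0])

# Fails: jax.grad is only defined for scalar-valued functions,
# and x -> P_jax(x, A) @ h returns a vector.
jax.grad(lambda x: P_jax(x, A) @ h)(x)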

Upvotes: 0

Views: 85

Answers (1)

Karl

Reputation: 5373

You need to pass an explicit gradient tensor (for example a tensor of ones) to backward when the output you backpropagate from is not a scalar:

import torch 

def P(x, A):
    x = x.unsqueeze(1)  # Convert to column vector
    vector = torch.matmul(A, x)
    denom = (vector.transpose(0, 1) @ vector).squeeze()
    P_matrix = (vector @ vector.transpose(0, 1)) / denom
    return P_matrix.squeeze()

A = torch.tensor([[1.0, 0.5], [0.5, 1.3]], dtype=torch.float32)
x = torch.tensor([1.0, 2.0], dtype=torch.float32, requires_grad=True)
h = torch.tensor([2.0, -1.0], dtype=torch.float32)

Pxh = torch.matmul(P(x, A), h)

Pxh.backward(torch.ones_like(Pxh))

x.grad
> tensor([ 0.4853, -0.2427])
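Note that this computes the vector-Jacobian product with a vector of ones, so x.grad is the sum of the gradients of the two components of P(x)h, not the matrix derivative itself. If what you actually want to check is the matrix DP(x)[h] = lim_{t -> 0} (P(x + t*h) - P(x)) / t, a forward-mode Jacobian-vector product gives it directly. A minimal sketch, reusing P, A, x and h from above:

from torch.autograd.functional import jvp

# Directional derivative of the matrix-valued map x -> P(x, A) in the
# direction h: jvp returns the pair (P(x, A), DP(x)[h]).
_, DPh = jvp(lambda x: P(x, A), x, h)

# Independent first-order finite-difference check of the same matrix
# (loose tolerance because everything is float32).
t = 1e-3
DPh_fd = (P(x + t * h, A) - P(x, A)) / t
print(torch.allclose(DPh, DPh_fd, atol=1e-2))

On the JAX side, jax.jvp plays the same role (jax.grad is restricted to scalar-valued functions, which is why it raised an error): with jnp versions of P, A, x and h, jax.jvp(lambda x: P(x, A), (x,), (h,)) returns the pair (P(x, A), DP(x)[h]).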

Upvotes: 1
