shmirrkk


How do I manually set partial derivatives for a multi-input function in PyTorch?

I am writing a machine learning program for my PhD that finds the poles of a rational function approximating the solution to a given differential equation. To calculate the loss, I need to compute an estimate given by a function whose inputs are the poles, the initial condition, and the Hamiltonian matrix of the differential equation. PyTorch's autograd does not compute the partial derivatives of this function correctly, so I need to supply the analytic partial derivatives manually. To do this I have been working from the following example:

https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html

This example only covers a function with a single input, so I am asking how it could be extended to return the partial derivative with respect to each input in the backward pass. In practice I only need the partial derivative with respect to the poles, so a way to treat the other inputs as constants and return only the poles' gradient in the backward pass would be sufficient.
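The general pattern for several inputs can be sketched as follows. This is a minimal illustration, not a solution to the pole problem itself: `ScaledShift` and the function `a * b + c` are hypothetical stand-ins, and all three inputs are assumed to have the same shape (with broadcasting, each gradient would have to be summed back to its input's shape). `backward` must return one value per `forward` input (excluding `ctx`), in the same order, with `None` for any input that should be treated as a constant. `ctx.needs_input_grad` reports which inputs actually require a gradient, so unneeded partials can be skipped. In the poles case, the poles would play the role of `a`, while the initial condition and Hamiltonian would take the places of `b` and `c` and get `None`.

```python
import torch


class ScaledShift(torch.autograd.Function):
    """Hypothetical example: out = a * b + c, with hand-written partials."""

    @staticmethod
    def forward(ctx, a, b, c):
        # Tensors needed in backward() go through save_for_backward.
        ctx.save_for_backward(a, b)
        return a * b + c

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = grad_b = grad_c = None
        # needs_input_grad holds one bool per forward() input, in order.
        if ctx.needs_input_grad[0]:
            grad_a = grad_output * b  # d(out)/da = b
        if ctx.needs_input_grad[1]:
            grad_b = grad_output * a  # d(out)/db = a
        if ctx.needs_input_grad[2]:
            grad_c = grad_output      # d(out)/dc = 1
        # One return value per forward() input; None means "no gradient".
        return grad_a, grad_b, grad_c


# Only `a` requires a gradient; `b` and `c` act as constants.
a = torch.rand(3, dtype=torch.double, requires_grad=True)
b = torch.rand(3, dtype=torch.double)
c = torch.rand(3, dtype=torch.double)

out = ScaledShift.apply(a, b, c)  # custom Functions are called via .apply
out.sum().backward()
print(a.grad)  # equals b, since d(sum(a * b + c))/da = b

# gradcheck compares the hand-written backward() with finite differences
# for every input that has requires_grad=True (double precision recommended).
print(torch.autograd.gradcheck(ScaledShift.apply, (a, b, c)))
```

`torch.autograd.gradcheck` is a convenient way to confirm that analytic partials written by hand agree with a numerical estimate before using them in a training loop.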

I have tried this for a much simpler two-input function, which multiplies its two inputs together and returns the partial derivative with respect to the first input (i.e. it returns the second input):

```python
import torch

class double_in(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, constant):
        ctx.save_for_backward(input, constant)
        output = input * constant
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, constant, = ctx.saved_tensors
        inputgrad = constant
        return torch.mul(grad_output, inputgrad)

x = torch.rand(1, requires_grad=True)
out = double_in(x, 5)
print("x = ", x, " out = ", out)
```

The code does not even compute the product of the two inputs; instead, the output is:

```
x =  tensor([0.2222], requires_grad=True)  out =  <__main__.dubin object at 0x0000022E87188120>
```
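A minimal corrected sketch of the same function follows, assuming `constant` stays a plain Python number. Three things appear to go wrong in the version above. First, calling `double_in(x, 5)` constructs an instance of the class instead of running `forward`, which is why an object is printed; custom Functions are invoked through `double_in.apply(...)`. Second, `ctx.save_for_backward` expects tensors, so a Python number such as `5` should be stored directly as an attribute on `ctx`. Third, `backward` has to return one value per `forward` input, so it returns `None` for `constant`, which is exactly how an input is treated as a constant.

```python
import torch


class double_in(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, constant):
        # constant is a plain Python number, so keep it on ctx directly;
        # save_for_backward is meant for tensors.
        ctx.constant = constant
        return input * constant

    @staticmethod
    def backward(ctx, grad_output):
        # d(input * constant)/d(input) = constant.
        # The second return value corresponds to `constant`: no gradient.
        return grad_output * ctx.constant, None


x = torch.rand(1, requires_grad=True)
out = double_in.apply(x, 5)  # call through .apply, not the constructor
out.backward()
print("x =", x, " out =", out, " x.grad =", x.grad)  # x.grad is 5
```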
