John_maddon

Reputation: 152

JAX - Problem differentiating a function

I am trying to run a Monte Carlo simulation of a call option in Python and then compute its first derivative with respect to the underlying asset price (the delta), but it still does not work:

from jax import random
from jax import jit, grad, vmap
import jax.numpy as jnp

xi = jnp.linspace(1,1.2,5)
def Simulation(xi):
    K,T,number_sim,sigma,r,q = 1.,1.,100,0.4,0,0
    S = jnp.broadcast_to(xi,(number_sim,len(xi))).T

    mean = -.5 * sigma * sigma * T
    volatility = sigma*jnp.sqrt(T)
    r_numb = random.PRNGKey(10)
    BM = mean + volatility * random.normal(r_numb, shape=(number_sim,))

    product = S*jnp.exp(BM)

    payoff = jnp.maximum(product-K,0)

    result = jnp.average(payoff, axis=1)*jnp.exp(-q*T)

    return result

first_derivative = vmap(grad(Simulation))(xi)

I do not know whether the way I implemented the algorithm is the best way to compute the derivative with automatic differentiation (AD). The algorithm works as follows:

[Short explanation of the algorithm (not available)]

I would really appreciate any advice or tips on how to fix this problem and compute the derivative with AD. Thanks in advance!
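As a reading aid, here is the estimator that the `Simulation` code computes for each spot price S, written out as math. The notation (M for `number_sim`, Z_m for the standard normal draws, B_m for `BM`) is ours, not the asker's. Note that the code discounts with q rather than r, which makes no difference here because both are 0. Differentiating each simulated path with respect to S gives the pathwise estimator, which is what reverse- or forward-mode AD produces for this code:

```latex
% Monte Carlo price estimate (M paths, Z_m ~ N(0, 1))
B_m = -\tfrac{1}{2}\sigma^2 T + \sigma\sqrt{T}\,Z_m, \qquad
\hat{C}(S) = e^{-qT}\,\frac{1}{M}\sum_{m=1}^{M} \max\!\left(S\,e^{B_m} - K,\ 0\right)

% Pathwise derivative with respect to the spot price S
\frac{\partial \hat{C}}{\partial S}
  = e^{-qT}\,\frac{1}{M}\sum_{m=1}^{M} e^{B_m}\,\mathbf{1}\{S\,e^{B_m} > K\}
```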

Upvotes: 1

Views: 1782

Answers (1)

jakevdp

Reputation: 86330

It appears that your function maps a vector Rᴺ→Rᴺ. There are two notions of a derivative that make sense in this case: an elementwise derivative (which in JAX you can compute by composing jax.vmap and jax.grad). This will return a derivative vector of length N, where element i contains the derivative of the ith output with respect to the ith input.
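A minimal sketch of the elementwise derivative, using a hypothetical toy function `f` (not the question's `Simulation`) whose derivative is known in closed form:

```python
import jax
import jax.numpy as jnp

# Toy elementwise function: output i depends only on input i.
def f(x):
    return jnp.sin(x) * x

xs = jnp.linspace(0.0, 1.0, 5)

# grad(f) differentiates the scalar -> scalar function f;
# vmap applies that derivative to each element of xs.
elementwise = jax.vmap(jax.grad(f))(xs)

# Analytic derivative of x * sin(x), for comparison.
expected = jnp.sin(xs) + xs * jnp.cos(xs)
print(jnp.allclose(elementwise, expected))  # True
```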

Alternatively, you can compute the jacobian matrix (using jax.jacobian) which will return a shape [N, N] matrix, where element i,j contains the derivative of the ith output with respect to the jth input.
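To see that `jax.jacobian` returns the full [N, N] matrix, here is a sketch with a hypothetical function `g` whose outputs depend on several inputs, so the Jacobian is not diagonal:

```python
import jax
import jax.numpy as jnp

# Output i is the sum of x[0]**2 ... x[i]**2, so it depends on inputs 0..i.
def g(x):
    return jnp.cumsum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
J = jax.jacobian(g)(x)  # J[i, j] = d g(x)[i] / d x[j]
print(J.shape)  # (3, 3)
print(J)
# [[2. 0. 0.]
#  [2. 4. 0.]
#  [2. 4. 6.]]
```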

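The issue you're having is that your function is written assuming a vector input (you ask for the length of xi), which implies you're interested in the jacobian. But vmap(grad(Simulation)) asks for the elementwise derivative, which requires a scalar-valued function of a scalar input: inside vmap, Simulation receives each element of xi as a scalar, so len(xi) fails, and grad itself only accepts functions with a scalar output.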
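A small sketch of the scalar-output requirement, using a hypothetical function `vector_valued`: calling `jax.grad` on a function with vector output raises a `TypeError`, while the same function works under `vmap`, where each call sees a scalar:

```python
import jax
import jax.numpy as jnp

def vector_valued(x):
    # Maps R^3 -> R^3 when given a vector, so grad cannot be applied directly.
    return 2.0 * x

raised = False
try:
    jax.grad(vector_valued)(jnp.ones(3))
except TypeError as err:
    raised = True
    print("grad failed:", err)

# Under vmap each call receives a scalar, so grad sees a scalar-output function.
elementwise = jax.vmap(jax.grad(vector_valued))(jnp.ones(3))
print(elementwise)  # [2. 2. 2.]
```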

So you have two possible ways of solving this, depending on what derivative you're interested in. If you're interested in the jacobian, you can use the function as written and use the jax.jacobian transform:

from jax import jacobian
print(jacobian(Simulation)(xi))
# [[0.6528027 0.        0.        0.        0.       ]
#  [0.        0.6819291 0.        0.        0.       ]
#  [0.        0.        0.7003516 0.        0.       ]
#  [0.        0.        0.        0.7181915 0.       ]
#  [0.        0.        0.        0.        0.7608434]]
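Because each output of `Simulation` depends only on the matching input, its Jacobian is diagonal. Under that assumption (it does not hold for functions that mix inputs), the diagonal can also be obtained without building the N x N matrix: the gradient of the summed outputs, or one forward-mode `jax.jvp` pass with an all-ones tangent, gives the same vector. This is a sketch of an alternative, not part of the original answer; exact numbers may differ from the output above across JAX versions because the random number generator's defaults can change:

```python
import jax
import jax.numpy as jnp
from jax import random

def Simulation(xi):
    # Same Monte Carlo call price as in the question.
    K, T, number_sim, sigma, r, q = 1., 1., 100, 0.4, 0, 0
    S = jnp.broadcast_to(xi, (number_sim, len(xi))).T
    mean = -.5 * sigma * sigma * T
    volatility = sigma * jnp.sqrt(T)
    BM = mean + volatility * random.normal(random.PRNGKey(10), shape=(number_sim,))
    payoff = jnp.maximum(S * jnp.exp(BM) - K, 0)
    return jnp.average(payoff, axis=1) * jnp.exp(-q * T)

xi = jnp.linspace(1, 1.2, 5)

# Full N x N Jacobian, then keep only its (non-zero) diagonal.
delta_from_jacobian = jnp.diag(jax.jacobian(Simulation)(xi))

# Diagonal Jacobian => gradient of the sum equals the diagonal (one backward pass).
delta_from_sum = jax.grad(lambda x: Simulation(x).sum())(xi)

# Equivalently, one forward-mode pass with an all-ones tangent vector.
_, delta_from_jvp = jax.jvp(Simulation, (xi,), (jnp.ones_like(xi),))

print(jnp.allclose(delta_from_jacobian, delta_from_sum))  # True
print(jnp.allclose(delta_from_jacobian, delta_from_jvp))  # True
```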

Alternatively, if you're interested in the elementwise gradient, you can rewrite your function to be compatible with scalar inputs, and use vmap of grad as you did in your example. Only two lines need to be changed:

def Simulation_scalar(xi):
    K,T,number_sim,sigma,r,q = 1.,1.,100,0.4,0,0

    # S = jnp.broadcast_to(xi,(number_sim,len(xi))).T
    S = jnp.broadcast_to(xi,(number_sim,) + xi.shape).T

    mean = -.5 * sigma * sigma * T
    volatility = sigma*jnp.sqrt(T)
    r_numb = random.PRNGKey(10)
    BM = mean + volatility * random.normal(r_numb, shape=(number_sim,))

    product = S*jnp.exp(BM)

    payoff = jnp.maximum(product-K,0)

    # result = jnp.average(payoff, axis=1)*jnp.exp(-q*T)
    result = jnp.average(payoff, axis=-1)*jnp.exp(-q*T)

    return result

print(vmap(grad(Simulation_scalar))(xi))
# [0.6528027 0.6819291 0.7003516 0.7181915 0.7608434]
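As a sanity check on the AD result, the pathwise delta can be compared with the closed-form Black-Scholes delta N(d1), which applies here because r = q = 0. With only 100 paths the Monte Carlo estimate is noisy, so this sketch uses many more paths. `mc_call_price` is a hypothetical variant of `Simulation_scalar` that takes pre-drawn normal samples instead of a fixed key:

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

K, T, sigma = 1.0, 1.0, 0.4  # same parameters as the question, with r = q = 0

def mc_call_price(s0, z):
    # Hypothetical scalar-input price estimate using standard normal samples z.
    bm = -0.5 * sigma**2 * T + sigma * jnp.sqrt(T) * z
    payoff = jnp.maximum(s0 * jnp.exp(bm) - K, 0.0)
    return jnp.mean(payoff)

z = random.normal(random.PRNGKey(0), shape=(500_000,))
xi = jnp.linspace(1, 1.2, 5)

# Pathwise (AD) delta for each spot price, sharing the same samples z.
ad_delta = jax.vmap(jax.grad(mc_call_price), in_axes=(0, None))(xi, z)

# Closed-form Black-Scholes delta N(d1) when r = q = 0.
d1 = (jnp.log(xi / K) + 0.5 * sigma**2 * T) / (sigma * jnp.sqrt(T))
bs_delta = norm.cdf(d1)

# Largest gap between the AD estimate and the closed form.
print(jnp.max(jnp.abs(ad_delta - bs_delta)))
```

The gap shrinks roughly like one over the square root of the number of paths, which is a quick way to confirm that the AD derivative targets the right quantity.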

Upvotes: 4
