Reputation: 152
I am trying to run a Monte Carlo simulation of a call option in Python and then compute its first derivative with respect to the underlying asset, but it still does not work.
from jax import random
from jax import jit, grad, vmap
import jax.numpy as jnp
xi = jnp.linspace(1,1.2,5)
def Simulation(xi):
    K,T,number_sim,sigma,r,q = 1.,1.,100,0.4,0,0
    S = jnp.broadcast_to(xi,(number_sim,len(xi))).T
    mean = -.5 * sigma * sigma * T
    volatility = sigma*jnp.sqrt(T)
    r_numb = random.PRNGKey(10)
    BM = mean + volatility * random.normal(r_numb, shape=(number_sim,))
    product = S*jnp.exp(BM)
    payoff = jnp.maximum(product-K,0)
    result = jnp.average(payoff, axis=1)*jnp.exp(-q*T)
    return result
first_derivative = vmap(grad(Simulation))(xi)
I do not know whether the way the algorithm is implemented is the best one for computing the derivative with the AD method; the algorithm works like this:
S = a matrix containing all the underlyings; each row corresponds to one of the values generated with xi = jnp.linspace, and that value is repeated number_sim times within the row
product = after generating BM (a vector of normal draws), each element of exp(BM) is multiplied by each element of every row of S
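To make the intended shapes explicit, here is a quick check (just an illustration with the same constants as above, not part of the pricing code):

xi = jnp.linspace(1, 1.2, 5)
number_sim = 100
S = jnp.broadcast_to(xi, (number_sim, len(xi))).T
print(S.shape)                          # (5, 100): one row per underlying value
BM = random.normal(random.PRNGKey(10), shape=(number_sim,))
print((S * jnp.exp(BM)).shape)          # (5, 100): every row multiplied by the same draws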
So this is a short explanation of the algorithm. I would really appreciate any advice or tips on how to handle this problem and compute the derivative with the AD method! Thanks in advance.
Upvotes: 1
Views: 1782
Reputation: 86330
It appears that your function maps a vector Rᴺ→Rᴺ. There are two notions of a derivative that make sense in this case. The first is the elementwise derivative, which in JAX you can compute by composing jax.vmap and jax.grad; this returns a derivative vector of length N, where element i contains the derivative of the ith output with respect to the ith input.
Alternatively, you can compute the Jacobian matrix (using jax.jacobian), which returns a matrix of shape [N, N], where element (i, j) contains the derivative of the ith output with respect to the jth input.
The issue you're having is that your function is written assuming a vector input (you ask for the length of xi), which implies you're interested in the Jacobian, but you are asking for the elementwise derivative, which requires a scalar-valued function.
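As a small illustration of how the two relate (using a toy elementwise function rather than your simulation), vmap of grad recovers exactly the diagonal of the Jacobian:

from jax import grad, vmap, jacobian
import jax.numpy as jnp

f = lambda x: jnp.sin(x) ** 2                            # scalar -> scalar toy function
x = jnp.linspace(0., 1., 4)

elementwise = vmap(grad(f))(x)                           # shape (4,): df(x_i)/dx_i
full_jacobian = jacobian(lambda v: jnp.sin(v) ** 2)(x)   # shape (4, 4)
print(jnp.allclose(elementwise, jnp.diag(full_jacobian)))
# True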
So you have two possible ways of solving this, depending on what derivative you're interested in. If you're interested in the Jacobian, you can use the function as written and apply the jax.jacobian transform:
from jax import jacobian
print(jacobian(Simulation)(xi))
# [[0.6528027 0. 0. 0. 0. ]
# [0. 0.6819291 0. 0. 0. ]
# [0. 0. 0.7003516 0. 0. ]
# [0. 0. 0. 0.7181915 0. ]
# [0. 0. 0. 0. 0.7608434]]
Alternatively, if you're interested in the elementwise gradient, you can rewrite your function to be compatible with scalar inputs, and use vmap of grad as you did in your example. Only two lines need to be changed:
def Simulation_scalar(xi):
    K,T,number_sim,sigma,r,q = 1.,1.,100,0.4,0,0
    # S = jnp.broadcast_to(xi,(number_sim,len(xi))).T
    S = jnp.broadcast_to(xi,(number_sim,) + xi.shape).T
    mean = -.5 * sigma * sigma * T
    volatility = sigma*jnp.sqrt(T)
    r_numb = random.PRNGKey(10)
    BM = mean + volatility * random.normal(r_numb, shape=(number_sim,))
    product = S*jnp.exp(BM)
    payoff = jnp.maximum(product-K,0)
    # result = jnp.average(payoff, axis=1)*jnp.exp(-q*T)
    result = jnp.average(payoff, axis=-1)*jnp.exp(-q*T)
    return result
print(vmap(grad(Simulation_scalar))(xi))
# [0.6528027 0.6819291 0.7003516 0.7181915 0.7608434]
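As a sanity check (assuming both Simulation and Simulation_scalar above are defined), the elementwise gradient should match the diagonal of the Jacobian of the original function:

print(jnp.allclose(jnp.diag(jacobian(Simulation)(xi)),
                   vmap(grad(Simulation_scalar))(xi)))
# True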
Upvotes: 4