Chen Lizi


Automatic differentiation of multivariate functions that may not be closed-form

I am trying to compute several first-order derivatives of a multivariate function that may or may not have a closed form. For context, I am computing the 'Greeks' of options. An option's price depends on quite a few inputs: the spot price, strike price, volatility, interest rate, and so forth. One of the most commonly used Greeks is delta, the change in the option's price per one-unit change in the spot price of the underlying stock. In general an option's price may not have a closed-form (analytic) expression, and in practice it can be computed by Monte Carlo simulation, although here I use a closed form for simplicity. The point is that I need a 'NumPy-friendly' way of computing first-order derivatives of some function. This is where I believe people with machine learning/deep learning experience may be able to help: I took an introductory machine learning class and know there is a whole world of automatic differentiation, backpropagation, and so on. The library I use here is JAX, and it seems to have an issue with NumPy. The error message is:

 The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(14793.626)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(14793.626, dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
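The failure can be reproduced in a few lines. In this sketch (the function names are hypothetical), the same logarithm is differentiated twice: once through plain NumPy, where `np.log` receives a JAX tracer and tries to convert it to a `numpy.ndarray`, and once through `jax.numpy`, which works. Recent JAX releases raise `TracerArrayConversionError` here; the code catches `Exception` broadly in case the type differs across versions.

```python
import jax
import jax.numpy as jnp
import numpy as np


def log_price_numpy(s):
    # Plain NumPy: np.log tries to convert the JAX tracer into a numpy.ndarray
    return np.log(s)


def log_price_jax(s):
    # jax.numpy: the tracer stays inside JAX, so it can be differentiated
    return jnp.log(s)


# d/ds log(s) = 1/s, so the gradient at s = 4.0 should be 0.25
jax_grad = jax.grad(log_price_jax)(4.0)
print(jax_grad)

try:
    jax.grad(log_price_numpy)(4.0)
    numpy_failed = False
except Exception as err:  # TracerArrayConversionError in recent JAX releases
    numpy_failed = True
    print(type(err).__name__)
```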

Note that I need to use a 'pricer', a pricing function written by someone else in NumPy, and rewriting it with another library is not feasible; it would be too much work. I have to apply his NumPy pricing function as it is.

By the way, I adapted the code below from a forum post. In the original code the function had five variables; all I did was add a sixth one, 'divyield' (the dividend yield), and it stopped working. Thank you very much! I appreciate any help or pointers!
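For reference, the closed form used in the code is the Black-Scholes-Merton price with a continuous dividend yield q, where N is the standard normal CDF and φ its density. Delta, vega, rho and theta are first-order partial derivatives of this price (call versions shown; theta is taken as the negative derivative with respect to time to maturity, which matches the sign flip applied to theta in the code):

```latex
C = S e^{-qT} N(d_1) - K e^{-rT} N(d_2), \qquad
P = K e^{-rT} N(-d_2) - S e^{-qT} N(-d_1)

d_1 = \frac{\ln(S/K) + \left(r - q + \tfrac{1}{2}\sigma^2\right)T}{\sigma\sqrt{T}}, \qquad
d_2 = d_1 - \sigma\sqrt{T}

\Delta = \frac{\partial C}{\partial S} = e^{-qT} N(d_1), \qquad
\mathcal{V} = \frac{\partial C}{\partial \sigma} = S e^{-qT} \varphi(d_1)\sqrt{T}, \qquad
\rho = \frac{\partial C}{\partial r} = K T e^{-rT} N(d_2), \qquad
\Theta = -\frac{\partial C}{\partial T}
```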

import numpy as np        # NumPy, as used by the third-party pricer
import scipy.stats as si  # SciPy's normal CDF (si.norm.cdf)
from jax import grad

class EuropeanCall:

    def __init__(self, inputs):
    
        self.spot_price = inputs[0]
        self.strike_price = inputs[1]
        self.time_to_expiration = inputs[2]
        self.risk_free_rate = inputs[3]
        self.divyield=inputs[4]
        self.volatility = inputs[5]
    
        self.price = EuropeanCall.black_scholes_call_div(self.spot_price, self.strike_price, self.time_to_expiration,
                                             self.risk_free_rate, self.divyield, self.volatility)

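        # grad's argnums are positional indices into (S, K, T, r, q, sigma):
        # 0 -> S (delta), 5 -> sigma (vega), 2 -> T (theta), 3 -> r (rho)
        self.gradient_func = grad(EuropeanCall.black_scholes_call_div, argnums=(0, 5, 2, 3))
        self.delta, self.vega, self.theta, self.rho = self.gradient_func(*inputs)
        self.theta /= -365  # theta per calendar day: (-dV/dT) / 365
        self.vega /= 100    # vega per 1 percentage point of volatility
        self.rho /= 100     # rho per 1 percentage point of interest rate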



    @staticmethod
    def black_scholes_call_div(S, K, T, r, q, sigma):
        # S: spot price
        # K: strike price
        # T: time to maturity (years)
        # r: risk-free interest rate
        # q: continuous dividend yield of the underlying
        # sigma: volatility of the underlying
        # r = r + cds
        d1 = (np.log(S / K) + (r - q + 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))
        d2 = (np.log(S / K) + (r - q - 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))

        call = (S * np.exp(-q * T) * si.norm.cdf(d1, 0.0, 1.0) - K * np.exp(-r * T) * si.norm.cdf(d2, 0.0, 1.0))

        return call

class EuropeanPut:

    def __init__(self, inputs):
    
        self.spot_price = inputs[0]
        self.strike_price = inputs[1]
        self.time_to_expiration = inputs[2]
        self.short_risk_free_rate = inputs[3]
        self.divyield=inputs[4]
        self.volatility = inputs[5]
    
        self.price = EuropeanPut.black_scholes_put_div(self.spot_price,  self.strike_price, self.time_to_expiration, 
                                            self.short_risk_free_rate,self.divyield,self.volatility)

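        # grad's argnums are positional indices into (S, K, T, r, q, sigma):
        # 0 -> S (delta), 5 -> sigma (vega), 2 -> T (theta), 3 -> r (rho)
        self.gradient_func = grad(EuropeanPut.black_scholes_put_div, argnums=(0, 5, 2, 3))
        self.delta, self.vega, self.theta, self.rho = self.gradient_func(*inputs)
        self.theta /= -365  # theta per calendar day: (-dV/dT) / 365
        self.vega /= 100    # vega per 1 percentage point of volatility
        self.rho /= 100     # rho per 1 percentage point of interest rate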



    @staticmethod
    def black_scholes_put_div(S, K, T, r, q, sigma):
        # S: spot price
        # K: strike price
        # T: time to maturity (years)
        # r: risk-free interest rate
        # q: continuous dividend yield of the underlying
        # sigma: volatility of the underlying
        # r = r + cds
        d1 = (np.log(S / K) + (r - q + 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))
        d2 = (np.log(S / K) + (r - q - 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))

        put = (K * np.exp(-r * T) * si.norm.cdf(-d2, 0.0, 1.0) - S * np.exp(-q * T) * si.norm.cdf(-d1, 0.0, 1.0))

        return put

# order read by the classes: S, K, T, r, q, sigma
inputs = np.array([3109.62, .2102, 27/365,.017,0.02,0.25])
ec = EuropeanCall(inputs.astype('float'))
print(ec.delta, ec.vega, ec.theta, ec.rho)
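Separately from the NumPy error, note that `grad`'s `argnums` are positional indices into the function's signature. With `q` inserted at position 4, `sigma` moves to index 5, so `(0, 1, 3, 4)` differentiates with respect to S, K, r and q rather than S, sigma, T and r. A toy check, using a hypothetical linear function with the same six-argument signature, whose partial derivatives are simply its coefficients:

```python
import jax


def f(S, K, T, r, q, sigma):
    # Hypothetical linear function: the partial derivative with respect to
    # each argument equals that argument's coefficient.
    return 1.0 * S + 2.0 * K + 3.0 * T + 4.0 * r + 5.0 * q + 6.0 * sigma


args = (100.0, 100.0, 0.5, 0.017, 0.02, 0.25)  # arbitrary float inputs

# argnums are positional: 0 -> S, 5 -> sigma, 2 -> T, 3 -> r.
# grad returns one partial derivative per index, in the order listed.
dS, dsigma, dT, dr = jax.grad(f, argnums=(0, 5, 2, 3))(*args)
print(dS, dsigma, dT, dr)
```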


Answers (1)

jakevdp


The error message tells you what you need to do:

You might want to check that you are using jnp together with import jax.numpy as jnp rather than using np via import numpy as np

JAX cannot differentiate NumPy functions, but it can differentiate jax.numpy functions. So replace np.log, np.sqrt, np.exp, etc. with jnp.log, jnp.sqrt, jnp.exp, etc., and similarly replace scipy calls with jax.scipy calls (for example, scipy.stats.norm.cdf with jax.scipy.stats.norm.cdf). Once all operations are implemented via JAX, you should be able to compute gradients with JAX.
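Applied to the call pricer from the question, that substitution might look like the following sketch (not code from the original answer; the input values are hypothetical). It also passes `argnums=(0, 5, 2, 3)` so the four gradients line up with delta, vega, theta and rho for the six-argument signature, and cross-checks delta against the closed form e^{-qT} N(d1):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def black_scholes_call_div(S, K, T, r, q, sigma):
    # Same formula as in the question, with jnp / jax.scipy instead of np / scipy
    d1 = (jnp.log(S / K) + (r - q + 0.5 * sigma ** 2) * T) / (sigma * jnp.sqrt(T))
    d2 = d1 - sigma * jnp.sqrt(T)
    return S * jnp.exp(-q * T) * norm.cdf(d1) - K * jnp.exp(-r * T) * norm.cdf(d2)


# argnums: 0 -> S (delta), 5 -> sigma (vega), 2 -> T (theta), 3 -> r (rho)
greeks = jax.grad(black_scholes_call_div, argnums=(0, 5, 2, 3))

S, K, T, r, q, sigma = 100.0, 100.0, 0.5, 0.017, 0.02, 0.25  # hypothetical inputs
delta, vega, dV_dT, rho = greeks(S, K, T, r, q, sigma)
theta_per_day = -dV_dT / 365  # same convention as the question's theta /= -365
print(delta, vega / 100, theta_per_day, rho / 100)

# Cross-check: the analytic call delta is e^{-qT} N(d1)
d1 = (jnp.log(S / K) + (r - q + 0.5 * sigma ** 2) * T) / (sigma * jnp.sqrt(T))
print(jnp.exp(-q * T) * norm.cdf(d1))
```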

If you're using a third-party module that is implemented in numpy that you cannot rewrite with JAX, then you will not be able to directly use JAX transforms, including auto-differentiation.
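When the pricer has to stay in NumPy (or is a Monte Carlo simulation), a common fallback outside of autodiff is to approximate each Greek by bumping one input and repricing. This sketch is not from the original answer: `numpy_pricer` is a hypothetical stand-in for the third-party function, `central_difference` is a hypothetical helper, and the relative step size of 1e-4 is an assumption that may need tuning. For a Monte Carlo pricer, the bumped runs should reuse the same random numbers (for example, the same seed); otherwise simulation noise can swamp the difference.

```python
import numpy as np
from scipy.stats import norm


def numpy_pricer(S, K, T, r, q, sigma):
    # Hypothetical stand-in for the third-party NumPy pricer; any function of
    # these six scalars would work, including a Monte Carlo pricer.
    d1 = (np.log(S / K) + (r - q + 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))
    d2 = d1 - sigma * np.sqrt(T)
    return S * np.exp(-q * T) * norm.cdf(d1) - K * np.exp(-r * T) * norm.cdf(d2)


def central_difference(f, args, index, rel_step=1e-4):
    # Bump one argument up and down and take the symmetric difference quotient.
    h = rel_step * max(abs(args[index]), 1.0)
    up = list(args)
    down = list(args)
    up[index] += h
    down[index] -= h
    return (f(*up) - f(*down)) / (2.0 * h)


args = (100.0, 100.0, 0.5, 0.017, 0.02, 0.25)  # S, K, T, r, q, sigma (hypothetical)
fd_delta = central_difference(numpy_pricer, args, 0)  # index 0 -> S
fd_vega = central_difference(numpy_pricer, args, 5)   # index 5 -> sigma

# Compare with the analytic delta e^{-qT} N(d1)
S, K, T, r, q, sigma = args
d1 = (np.log(S / K) + (r - q + 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))
exact_delta = np.exp(-q * T) * norm.cdf(d1)
print(fd_delta, exact_delta)
```

This costs two pricer calls per Greek and is only an approximation, but it works with any function that returns a number.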

