user3613025

Reputation: 383

How to overwrite attributes in OOP inheritance with Python dataclass

Currently, I have some code that looks like this (irrelevant methods removed):

import math
import numpy as np
from decimal import Decimal
from dataclasses import dataclass, field
from typing import Optional, List

@dataclass
class A:
    S0: int
    K: int
    r: float = 0.05
    T: int = 1
    N: int = 2
    StockTrees: List[float] = field(init=False, default_factory=list)
    pu: Optional[float] = 0
    pd: Optional[float] = 0
    div: Optional[float] = 0
    sigma: Optional[float] = 0
    is_put: Optional[bool] = field(default=False)
    is_american: Optional[bool] = field(default=False)
    is_call: Optional[bool] = field(init=False)
    is_european: Optional[bool] = field(init=False)
    
    def __post_init__(self):
        self.is_call = not self.is_put
        self.is_european = not self.is_american
        
    @property
    def dt(self):
        return self.T/float(self.N)
    
    @property
    def df(self):
        return math.exp(-(self.r - self.div) * self.dt)

@dataclass
class B(A):

    u: float = field(init=False)
    d: float = field(init=False)
    qu: float = field(init=False)
    qd: float = field(init=False)
    
    def __post_init__(self):
        super().__post_init__()
        self.u = 1 + self.pu
        self.d = 1 - self.pd
        self.qu = (math.exp((self.r - self.div) * self.dt) - self.d)/(self.u - self.d)
        self.qd = 1 - self.qu
    
    
@dataclass
class C(B):
    def __post_init__(self):
        super().__post_init__()
        self.u = math.exp(self.sigma * math.sqrt(self.dt))
        self.d = 1/self.u
        self.qu = (math.exp((self.r - self.div)*self.dt) - self.d)/(self.u - self.d)
        self.qd = 1 - self.qu

Basically, class A defines attributes that all of its child classes share. It is only meant to be initialised by instantiating one of its child classes, which inherit its attributes. The child class B performs a calculation, and C inherits from B to perform a variation of that same calculation: C inherits all of B's methods, and its only difference is how it calculates self.u and self.d.
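Written out from the code above, the two calculations differ only in how u and d are defined, while Δt, df, q_u and q_d are shared (div is the dividend yield field of A):

$$\Delta t = \frac{T}{N}, \qquad \mathrm{df} = e^{-(r - \mathrm{div})\,\Delta t}$$
$$\text{B:}\quad u = 1 + p_u, \qquad d = 1 - p_d$$
$$\text{C:}\quad u = e^{\sigma\sqrt{\Delta t}}, \qquad d = \frac{1}{u}$$
$$q_u = \frac{e^{(r - \mathrm{div})\,\Delta t} - d}{u - d}, \qquad q_d = 1 - q_u$$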

The code can be run with either B's calculation, which requires the arguments pu and pd, or C's calculation, which requires the argument sigma:

if __name__ == "__main__":
    
    am_option = B(50, 52, r=0.05, T=2, N=2, pu=0.2, pd=0.2, is_put=True, is_american=True)
    print(f"{am_option.sigma = }")
    print(f"{am_option.pu = }")
    print(f"{am_option.pd = }")
    print(f"{am_option.qu = }")
    print(f"{am_option.qd = }")
    
    eu_option2 = C(50, 52, r=0.05, T=2, N=2, sigma=0.3, is_put=True)
    print(f"{eu_option2.sigma = }")
    print(f"{eu_option2.pu = }")
    print(f"{eu_option2.pd = }")
    print(f"{eu_option2.qu = }")
    print(f"{eu_option2.qd = }")

which gives the output

am_option.pu = 0.2
am_option.pd = 0.2
am_option.qu = 0.6281777409400603
am_option.qd = 0.3718222590599397
Traceback (most recent call last):
  File "/home/dazza/option_pricer/test.py", line 136, in <module>
    eu_option2 = C(50, 52, r=0.05, T=2, N=2, sigma=0.3, is_put=True)
  File "<string>", line 15, in __init__
  File "/home/dazza/option_pricer/test.py", line 109, in __post_init__
    super().__post_init__()
  File "/home/dazza/option_pricer/test.py", line 55, in __post_init__
    self.qu = (math.exp((self.r - self.div) * self.dt) - self.d)/(self.u - self.d)
ZeroDivisionError: float division by zero

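So instantiating B works fine and gives the expected values for pu, pd, qu and qd. The problem is instantiating C: C's __post_init__ first calls super().__post_init__(), which runs B's calculation. Because pu and pd default to 0, B sets self.u = 1 and self.d = 1, so self.u - self.d is 0 and the qu line divides by zero before C gets the chance to overwrite u and d.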
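To see the order of events, here is a minimal sketch with simplified stand-ins for B and C. Only pu, pd and sigma are kept, and r - div = 0.05 with dt = 1 is assumed. The `trace` list is a hypothetical helper that records what runs: C's __post_init__ starts, hands off to B's, and B divides by u - d = 0 before C's own lines are reached.

```python
import math
from dataclasses import dataclass

trace = []  # hypothetical helper: records which __post_init__ ran and with what u, d

@dataclass
class B:
    # Simplified stand-in for the question's B (A's other fields omitted)
    pu: float = 0.0
    pd: float = 0.0
    sigma: float = 0.0

    def __post_init__(self):
        self.u = 1 + self.pu   # 1 + 0 = 1 when pu keeps its default
        self.d = 1 - self.pd   # 1 - 0 = 1 when pd keeps its default
        trace.append(("B", self.u, self.d))
        # Assumes (r - div) * dt = 0.05; the denominator u - d is 0 here
        self.qu = (math.exp(0.05) - self.d) / (self.u - self.d)

@dataclass
class C(B):
    def __post_init__(self):
        trace.append(("C starts",))
        super().__post_init__()             # B's formulas run first and raise
        trace.append(("C overrides u, d",))  # never reached
        self.u = math.exp(self.sigma)
        self.d = 1 / self.u

error = None
try:
    C(sigma=0.3)
except ZeroDivisionError as exc:
    error = exc

print(repr(error))
print(trace)
```

The trace shows that B's step ran with u = d = 1.0 and that C's override line was never reached.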

My question: how can I fix this so that C inherits all of A's attribute initialisation (including its __post_init__) and all of B's methods, while C's own calculation, self.u = math.exp(self.sigma * math.sqrt(self.dt)) and self.d = 1/self.u, overrides B's self.u = 1 + self.pu and self.d = 1 - self.pd? The calculation of self.qu and self.qd should stay as it is, since it is the same for B and C.

Upvotes: 0

Views: 3095

Answers (2)

Elan-R

Reputation: 554

Python supports multiple inheritance. You can inherit from A before B, which means any overlapping methods (such as __post_init__) will be taken from A. Any code you write in class C overrides what it inherits from A and B. If you need more control over which methods come from which class, you can always define the method in C and call A's or B's version explicitly (for example A.__post_init__(self); for a property such as dt, the equivalent is A.dt.fget(self)).

class C(A, B):
    ...

EDIT: I just saw that A initializes some attributes you want in C. Because C's parent is now A (if you used my code above), you can add the super().__post_init__() line back into C's __post_init__ so that it calls A's __post_init__. If this doesn't work, you can always call A.__post_init__(self) directly inside C's __post_init__.

Upvotes: 0

chepner

Reputation: 531430

Define another method to initialize u and d, so that you can override that part of B without overriding how qu and qd are defined.

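import math
from dataclasses import dataclass, field

# A is the base dataclass from the question, unchanged.

@dataclass
class B(A):

    u: float = field(init=False)
    d: float = field(init=False)
    qu: float = field(init=False)
    qd: float = field(init=False)
    
    def __post_init__(self):
        super().__post_init__()
        # Subclasses customise only this step; qu and qd below stay shared.
        self._define_u_and_d()
        self.qu = (math.exp((self.r - self.div) * self.dt) - self.d)/(self.u - self.d)
        self.qd = 1 - self.qu

    def _define_u_and_d(self):
        # B's definition, driven by the pu and pd arguments
        self.u = 1 + self.pu
        self.d = 1 - self.pd


@dataclass
class C(B):
    def _define_u_and_d(self):
        # C's definition, driven by sigma; B.__post_init__ calls this override instead
        self.u = math.exp(self.sigma * math.sqrt(self.dt))
        self.d = 1/self.u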
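To check that this refactoring removes the traceback, here is a self-contained sketch that combines a trimmed copy of the question's A (StockTrees and df are dropped because they are not needed here) with the B and C above, then runs both instantiations from the question.

```python
import math
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class A:
    # Trimmed copy of the question's base class
    S0: int
    K: int
    r: float = 0.05
    T: int = 1
    N: int = 2
    pu: Optional[float] = 0
    pd: Optional[float] = 0
    div: Optional[float] = 0
    sigma: Optional[float] = 0
    is_put: Optional[bool] = False
    is_american: Optional[bool] = False
    is_call: Optional[bool] = field(init=False)
    is_european: Optional[bool] = field(init=False)

    def __post_init__(self):
        self.is_call = not self.is_put
        self.is_european = not self.is_american

    @property
    def dt(self):
        return self.T / float(self.N)

@dataclass
class B(A):
    u: float = field(init=False)
    d: float = field(init=False)
    qu: float = field(init=False)
    qd: float = field(init=False)

    def __post_init__(self):
        super().__post_init__()
        self._define_u_and_d()  # hook: the only step subclasses change
        self.qu = (math.exp((self.r - self.div) * self.dt) - self.d) / (self.u - self.d)
        self.qd = 1 - self.qu

    def _define_u_and_d(self):
        self.u = 1 + self.pu
        self.d = 1 - self.pd

@dataclass
class C(B):
    def _define_u_and_d(self):
        self.u = math.exp(self.sigma * math.sqrt(self.dt))
        self.d = 1 / self.u

am_option = B(50, 52, r=0.05, T=2, N=2, pu=0.2, pd=0.2, is_put=True, is_american=True)
eu_option2 = C(50, 52, r=0.05, T=2, N=2, sigma=0.3, is_put=True)
print(f"{am_option.qu = }")
print(f"{eu_option2.u = }")
print(f"{eu_option2.qu = }")
```

Because B.__post_init__ looks up `self._define_u_and_d` at run time, a C instance gets C's version, so u and d are set from sigma before the shared qu line runs.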

Upvotes: 2
