jss367

Reputation: 5401

Define an attribute with a dataclass method in Python

I have a class for bounding box coordinates that I would like to convert to a dataclass, but I cannot figure out how to set an attribute from one of the class's own methods, as I would in a normal class. Here is the normal class:

class BBoxCoords:
    """Class for bounding box coordinates"""
    def __init__(self, top_left_x: float, top_left_y: float, bottom_right_x: float, bottom_right_y: float):
        self.top_left_x = top_left_x
        self.top_left_y = top_left_y
        self.bottom_right_x = bottom_right_x
        self.bottom_right_y = bottom_right_y
        self.height = self.get_height()

    def get_height(self) -> float:
        return self.bottom_right_y - self.top_left_y

and here is what I want it to do:

bb = BBoxCoords(1, 1, 5, 5)
bb.height
> 4

This is exactly what I want. I tried to do the same thing with a dataclass:

from dataclasses import dataclass

@dataclass
class BBoxCoords:
    """Class for bounding box coordinates"""
    top_left_x: float
    top_left_y: float
    bottom_right_x: float
    bottom_right_y: float
    height = self.get_height()  # raises NameError: there is no self in the class body

    def get_height(self) -> float:
        return self.bottom_right_y - self.top_left_y

but self isn't defined in the class body, so I get a NameError. What's the correct way of doing this with a dataclass? I know I could do:

bb = BBoxCoords(1, 1, 5, 5)
bb.get_height()
> 4

but I would rather call an attribute than a method.
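The NameError comes from when the class body runs: Python executes it once, while creating the class, before any instance (and therefore any self) exists. A minimal sketch reproducing this, using a hypothetical Broken class made up for illustration:

```python
# The class body executes once, at class-creation time.
# No instance exists yet, so the name `self` is not bound here.
try:
    class Broken:
        top_left_y: float = 1.0
        bottom_right_y: float = 5.0
        # Evaluated immediately while the class is being built -> NameError.
        height = self.bottom_right_y - self.top_left_y
except NameError as exc:
    error_message = str(exc)

print(error_message)  # name 'self' is not defined
```

Anything that needs instance values therefore has to run after construction, not in the class body.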

Upvotes: 3

Views: 4122

Answers (1)

juanpa.arrivillaga

Reputation: 96257

For this sort of thing, you need __post_init__, which the generated __init__ calls after it has set the fields. Also, declare height with field(init=False) so that it isn't a parameter of __init__:

from dataclasses import dataclass, field   

@dataclass
class BBoxCoords:
    """Class for bounding box coordinates"""
    top_left_x: float
    top_left_y: float
    bottom_right_x: float
    bottom_right_y: float
    height: float = field(init=False)

    def __post_init__(self):
        self.height = self.get_height()

    def get_height(self) -> float:
        return self.bottom_right_y - self.top_left_y
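Because height is declared with field(init=False), it is not a parameter of the generated __init__, and __post_init__ computes it once at construction. A short sketch (repeating the definition above so the block runs on its own) showing both consequences: passing height is rejected, and the stored value does not follow later changes to the coordinates:

```python
from dataclasses import dataclass, field


@dataclass
class BBoxCoords:
    """Class for bounding box coordinates"""
    top_left_x: float
    top_left_y: float
    bottom_right_x: float
    bottom_right_y: float
    height: float = field(init=False)  # excluded from __init__'s parameters

    def __post_init__(self):
        # Called by the generated __init__ after all init fields are set.
        self.height = self.get_height()

    def get_height(self) -> float:
        return self.bottom_right_y - self.top_left_y


bb = BBoxCoords(1, 1, 5, 5)
print(bb.height)  # 4

# height is not an __init__ parameter, so supplying it raises TypeError.
try:
    BBoxCoords(1, 1, 5, 5, height=10)
except TypeError:
    print("height cannot be passed to __init__")

# height was computed once in __post_init__; it does not track later edits.
bb.bottom_right_y = 9
print(bb.height)  # still 4
```

If the coordinates are never modified after construction, this stale-value behavior does not matter.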

In action:

In [1]: from dataclasses import dataclass, field
   ...:
   ...: @dataclass
   ...: class BBoxCoords:
   ...:     """Class for bounding box coordinates"""
   ...:     top_left_x: float
   ...:     top_left_y: float
   ...:     bottom_right_x: float
   ...:     bottom_right_y: float
   ...:     height: float = field(init=False)
   ...:
   ...:     def __post_init__(self):
   ...:         self.height = self.get_height()
   ...:
   ...:     def get_height(self) -> float:
   ...:         return self.bottom_right_y - self.top_left_y
   ...:

In [2]: BBoxCoords(1, 1, 5, 5)
Out[2]: BBoxCoords(top_left_x=1, top_left_y=1, bottom_right_x=5, bottom_right_y=5, height=4)
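If height should always reflect the current coordinates, a read-only @property is an alternative not covered in the answer above: it is still read as bb.height, but it is recomputed on each access. A minimal sketch; note that a property has no type annotation, so it is not a dataclass field and does not appear in the generated __repr__ or __eq__:

```python
from dataclasses import dataclass


@dataclass
class BBoxCoords:
    """Bounding box coordinates with a computed height (property variant)."""
    top_left_x: float
    top_left_y: float
    bottom_right_x: float
    bottom_right_y: float

    @property
    def height(self) -> float:
        # Recomputed on every access, so it always matches the coordinates.
        return self.bottom_right_y - self.top_left_y


bb = BBoxCoords(1, 1, 5, 5)
print(bb.height)  # 4
bb.bottom_right_y = 9
print(bb.height)  # 8
print(bb)  # BBoxCoords(top_left_x=1, top_left_y=1, bottom_right_x=5, bottom_right_y=9)
```

The trade-off is a subtraction on every access versus the field approach, which stores the value once and shows it in the repr.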

Upvotes: 7
