lhoupert

Reputation: 704

How to define a dataclass so each of its attributes is the list of its subclass attributes?

I have this code:

from dataclasses import dataclass
from typing import List

@dataclass
class Position:
    name: str
    lon: float
    lat: float

@dataclass
class Section:
    positions: List[Position]

pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2 , pos3])

print(sec.positions)

How can I add attributes to the dataclass Section so that each one is a list of the corresponding attribute of the nested Position objects?

In my example, I would like the Section object to also provide:

sec.name = ['a', 'b', 'c']   #[pos1.name,pos2.name,pos3.name]
sec.lon = [52, 46, 45]       #[pos1.lon,pos2.lon,pos3.lon]
sec.lat = [10, -10, -10]     #[pos1.lat,pos2.lat,pos3.lat]

I tried to define the dataclass as:

@dataclass
class Section:
    positions: List[Position]
    names :  List[Position.name]

But this does not work, because name is not an attribute of the Position class. I can define the attribute later in the code (e.g. with sec.name = [x.name for x in sec.positions]), but it would be nicer if it could be done at the dataclass definition level.

After posting this question, I found the beginning of an answer (https://stackoverflow.com/a/65222586/13890678).

But I was wondering whether there is a more generic, "automatic" way of defining the Section methods .names(), .lons(), .lats(), and so on, so that the developer doesn't have to define each method individually and the methods are instead created from the attributes of the Position objects.

Upvotes: 6

Views: 5298

Answers (3)

Arne

Reputation: 20157

As I understand it, you'd like to declare dataclasses that are flat data containers (like Position), which are nested into a container of another dataclass (like Section). The outer dataclass should then be able to access a list of all the attributes of its inner dataclass(es) through simple name access.

We can implement this kind of functionality (calling it, for example, introspect) on top of how a regular dataclass works, and can enable it on demand, similar to the already existing flags:

from dataclasses import is_dataclass, fields, dataclass as dc
from typing import List  # needed for the List[Position] annotation in the usage example

# existing dataclass signature, plus the "introspect" keyword
def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False, introspect=False):

    def wrap(cls):
        # run original dataclass decorator
        dc(cls, init=init, repr=repr, eq=eq, order=order,
           unsafe_hash=unsafe_hash, frozen=frozen)

        # add our custom "introspect" logic on top
        if introspect:
            for field in fields(cls):
                # only consider nested dataclass in containers
                try:
                    name = field.type._name
                except AttributeError:
                    continue
                if name not in ("List", "Set", "Tuple"):
                    continue
                contained_dc = field.type.__args__[0]
                if not is_dataclass(contained_dc):
                    continue
                # once we got them, add their fields as properties
                for dc_field in fields(contained_dc):
                    # if there are name-clashes, use f"{field.name}_{dc_field.name}" instead
                    property_name = dc_field.name
                    # bind variables explicitly to avoid funny bugs
                    def magic_property(self, field=field, dc_field=dc_field):
                        return [getattr(attr, dc_field.name) for attr in getattr(self, field.name)]
                    # here is where the magic happens
                    setattr(
                        cls,
                        property_name,
                        property(magic_property)
                    )
        return cls

    # Handle being called with or without parens
    if _cls is None:
        return wrap
    return wrap(_cls)

The resulting dataclass-function can now be used in the following way:

# regular dataclass
@dataclass
class Position:
    name: str
    lon: float
    lat: float
    
# this one will introspect its fields and try to add magic properties
@dataclass(introspect=True)
class Section:
    positions: List[Position]

And that's it. The properties get added during class construction, and will even update accordingly if any of the objects changes during its lifetime:

>>> p_1 = Position("1", 1.0, 1.1)
>>> p_2 = Position("2", 2.0, 2.1)
>>> p_3 = Position("3", 3.0, 3.1)
>>> section = Section([p_1 , p_2, p_3])
>>> section.name
['1', '2', '3']
>>> section.lon
[1.0, 2.0, 3.0]
>>> p_1.lon = 5.0
>>> section.lon
[5.0, 2.0, 3.0]
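
The decorator above relies on the private `_name` attribute of typing objects. A minimal sketch of the same idea using the public `typing.get_origin` and `typing.get_args` helpers (Python 3.8+) could look like the following. The helper name `add_list_properties` is made up for illustration, and the sketch assumes annotations are real types rather than strings (so no `from __future__ import annotations`):

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import List, get_args, get_origin


def add_list_properties(cls):
    """Hypothetical helper: add one list-valued property per nested field."""
    for outer in fields(cls):
        # get_origin(List[Position]) is list; plain types such as str give None
        if get_origin(outer.type) not in (list, set, tuple):
            continue
        args = get_args(outer.type)
        if not args or not is_dataclass(args[0]):
            continue
        for inner in fields(args[0]):
            # bind the names as defaults so each property keeps its own pair
            def getter(self, outer_name=outer.name, inner_name=inner.name):
                return [getattr(item, inner_name) for item in getattr(self, outer_name)]
            setattr(cls, inner.name, property(getter))
    return cls


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@add_list_properties
@dataclass
class Section:
    positions: List[Position]


section = Section([Position("1", 1.0, 1.1), Position("2", 2.0, 2.1)])
print(section.name)  # ['1', '2']
print(section.lon)   # [1.0, 2.0]
```

Because it is applied on top of the standard `@dataclass`, this variant does not need to mirror the decorator's keyword arguments.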

Upvotes: 1

Maurice Meyer

Reputation: 18106

You could create a new field after __init__ was called:

from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]
    _pos: dict = field(init=False, repr=False)

    def __post_init__(self):
        # create _pos after init is done, read only!
        Section._pos = property(Section._get_positions)

    def _get_positions(self):
        _pos = {}

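        # iterate over all field names and collect the values into _pos
        # (the loop variable is not called "field" to avoid shadowing the import)
        for field_name in [f.name for f in fields(self.positions[0])]:
            if field_name not in _pos:
                _pos[field_name] = []

            for p in self.positions:
                _pos[field_name].append(getattr(p, field_name))
        return _pos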


pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)
print(sec._pos['name'])
print(sec._pos['lon'])
print(sec._pos['lat'])

Out:

[Position(name='a', lon=52, lat=10), Position(name='b', lon=46, lat=-10), Position(name='c', lon=45, lat=-10)]
['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
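
As a possible simplification, the same dictionary can be exposed through an ordinary property defined in the class body, without declaring an extra field or assigning to the class in `__post_init__`. This is only a sketch: the property name `columns` is hypothetical, and it assumes every item in `positions` is a `Position`:

```python
from dataclasses import dataclass, fields
from typing import Dict, List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

    @property
    def columns(self) -> Dict[str, list]:
        # hypothetical name; rebuilt on every access, so later changes show up
        return {
            f.name: [getattr(p, f.name) for p in self.positions]
            for f in fields(Position)
        }


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec.columns['name'])  # ['a', 'b', 'c']
print(sec.columns['lon'])   # [52, 46, 45]
```

A property has no annotation in the class body, so `@dataclass` does not treat it as a field and it stays out of `__init__`, `repr` and `eq`.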

Edit:

If you need something more generic, you could override __getattr__:

from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

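    def __getattr__(self, keyName):
        # only called when normal attribute lookup fails
        for f in fields(self.positions[0]):
            if f"{f.name}s" == keyName:
                return [getattr(x, f.name) for x in self.positions]
        # unknown name: raise so hasattr() and getattr() defaults keep working
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {keyName!r}")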

pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.names)
print(sec.lons)
print(sec.lats)

Out:

['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
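
Two details matter when overriding `__getattr__`: it should raise `AttributeError` for names it does not handle, and looking at `self.positions[0]` fails with an `IndexError` on an empty list. A minimal sketch that reads the field names from the `Position` class instead, assuming all items are `Position` instances:

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

    def __getattr__(self, key_name):
        # reached only when normal attribute lookup has already failed
        for f in fields(Position):
            if key_name == f"{f.name}s":
                return [getattr(p, f.name) for p in self.positions]
        # anything else is a genuine missing attribute
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {key_name!r}"
        )


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec.lons)                     # [52, 46, 45]
print(hasattr(sec, "altitudes"))    # False
print(Section([]).names)            # []
```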

Upvotes: 6

lhoupert

Reputation: 704

After some more thought, I came up with an alternative solution using methods:


from dataclasses import dataclass
from typing import List

@dataclass
class Position:
    name: str
    lon: float
    lat: float

@dataclass
class Section:
    positions: List[Position]

    def names(self):
        return [x.name for x in self.positions]

    def lons(self):
        return [x.lon for x in self.positions]

    def lats(self):
        return [x.lat for x in self.positions]


pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2 , pos3])

print(sec.positions)
print(sec.names())
print(sec.lons())
print(sec.lats())

But I was wondering whether there is a more generic, "automatic" way of defining the Section methods .names(), .lons(), .lats(), and so on, so that the developer doesn't have to define each method individually and the methods are instead created from the attributes of the Position objects.
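
One way such automatic generation might look is a small class decorator that loops over the fields of Position and attaches one method per field. This is only a sketch: the decorator name `plural_accessors` and its `container` parameter are made up for illustration, and the plural is formed by simply appending "s":

```python
from dataclasses import dataclass, fields
from typing import List


def plural_accessors(item_cls, container="positions"):
    """Hypothetical class decorator: add a `<field>s()` method per field of item_cls."""
    def decorate(cls):
        for f in fields(item_cls):
            # bind the field name as a default so each method keeps its own
            def accessor(self, _name=f.name):
                return [getattr(item, _name) for item in getattr(self, container)]
            accessor.__name__ = f"{f.name}s"
            setattr(cls, f"{f.name}s", accessor)
        return cls
    return decorate


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@plural_accessors(Position)
@dataclass
class Section:
    positions: List[Position]


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec.names())  # ['a', 'b', 'c']
print(sec.lons())   # [52, 46, 45]
print(sec.lats())   # [10, -10, -10]
```

Adding a field to Position then yields a matching method on Section without further changes.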

Upvotes: 1
