Reputation: 704
I have this code:
from dataclasses import dataclass
from typing import List
@dataclass
class Position:
    name: str
    lon: float
    lat: float
@dataclass
class Section:
    positions: List[Position]
pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)
sec = Section([pos1, pos2 , pos3])
print(sec.positions)
How can I create additional attributes in the dataclass Section so that they are lists of the corresponding attributes of the Position objects it contains?
In my example, I would like the section object to also provide:
sec.name = ['a', 'b', 'c'] #[pos1.name,pos2.name,pos3.name]
sec.lon = [52, 46, 45] #[pos1.lon,pos2.lon,pos3.lon]
sec.lat = [10, -10, -10] #[pos1.lat,pos2.lat,pos3.lat]
I tried to define the dataclass as:
@dataclass
class Section:
    positions: List[Position]
    names : List[Position.name]
But it is not working, because name is not an attribute of the Position class (it only exists on Position instances). I can add the attributes later in the code (e.g. by doing sec.name = [x.name for x in sec.positions]). But it would be nicer if it could be done at the dataclass definition level.
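For comparison, the manual list comprehension can be moved into the class definition with `field(init=False)` plus `__post_init__`. This is a minimal sketch, assuming a one-time snapshot taken at construction is acceptable: the lists are computed once and do not follow later changes to the Position objects.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]
    # derived lists: excluded from __init__, filled in by __post_init__
    name: List[str] = field(init=False)
    lon: List[float] = field(init=False)
    lat: List[float] = field(init=False)

    def __post_init__(self):
        # snapshot of the values at construction time;
        # later changes to the Position objects are NOT reflected here
        self.name = [p.name for p in self.positions]
        self.lon = [p.lon for p in self.positions]
        self.lat = [p.lat for p in self.positions]


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec.name)  # ['a', 'b', 'c']
print(sec.lon)   # [52, 46, 45]
print(sec.lat)   # [10, -10, -10]
```

Every attribute still has to be listed by hand here, which is what the rest of the thread tries to avoid.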
After posting this question, I found the beginning of an answer (https://stackoverflow.com/a/65222586/13890678).
But I was wondering whether there is a more generic/"automatic" way of defining the Section methods .names(), .lons(), .lats(), ..., so that the developer doesn't have to define each method individually, and the methods are instead created from the attributes of the Position objects?
Upvotes: 6
Views: 5298
Reputation: 20157
The way I understood you, you'd like to declare dataclasses that are flat data containers (like Position), which are nested into a container of another dataclass (like Section). The outer dataclass should then be able to access a list of all the attributes of its inner dataclass(es) through simple name access.
We can implement this kind of functionality (calling it, for example, introspect) on top of how a regular dataclass works, and enable it on demand, similar to the already existing flags:
from dataclasses import is_dataclass, fields, dataclass as dc
# existing dataclass signature, plus "introspect" keyword
def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False, introspect=False):
    def wrap(cls):
        # run original dataclass decorator
        dc(cls, init=init, repr=repr, eq=eq, order=order,
           unsafe_hash=unsafe_hash, frozen=frozen)
        # add our custom "introspect" logic on top
        if introspect:
            for field in fields(cls):
                # only consider nested dataclasses in containers
                try:
                    name = field.type._name
                except AttributeError:
                    continue
                if name not in ("List", "Set", "Tuple"):
                    continue
                contained_dc = field.type.__args__[0]
                if not is_dataclass(contained_dc):
                    continue
                # once we got them, add their fields as properties
                for dc_field in fields(contained_dc):
                    # if there are name clashes, use f"{field.name}_{dc_field.name}" instead
                    property_name = dc_field.name
                    # bind loop variables explicitly to avoid late-binding bugs
                    def magic_property(self, field=field, dc_field=dc_field):
                        return [getattr(attr, dc_field.name) for attr in getattr(self, field.name)]
                    # here is where the magic happens
                    setattr(
                        cls,
                        property_name,
                        property(magic_property)
                    )
        return cls
    # Handle being called with or without parens
    if _cls is None:
        return wrap
    return wrap(_cls)
The resulting dataclass function can now be used in the following way:
from typing import List
# regular dataclass
@dataclass
class Position:
    name: str
    lon: float
    lat: float
# this one will introspect its fields and try to add magic properties
@dataclass(introspect=True)
class Section:
    positions: List[Position]
And that's it. The properties get added during class construction, and will even update accordingly if any of the objects changes during its lifetime:
>>> p_1 = Position("1", 1.0, 1.1)
>>> p_2 = Position("2", 2.0, 2.1)
>>> p_3 = Position("3", 3.0, 3.1)
>>> section = Section([p_1 , p_2, p_3])
>>> section.name
['1', '2', '3']
>>> section.lon
[1.0, 2.0, 3.0]
>>> p_1.lon = 5.0
>>> section.lon
[5.0, 2.0, 3.0]
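The container check above relies on the private `_name` attribute of `typing` aliases, so a field annotated with the builtin `list[Position]` (Python 3.9+) has no such attribute and is skipped. A sketch of the same idea using the public `typing.get_origin`/`get_args` helpers instead, written as a separate class decorator stacked on top of the standard `@dataclass`; the name `collect_fields` is hypothetical. String annotations (e.g. from `from __future__ import annotations`) are not resolved and would also be skipped.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import List, get_args, get_origin


def collect_fields(cls):
    """Hypothetical class decorator: for every field typed as a list/set/tuple
    of a dataclass, add one read-only property per field of that inner dataclass."""
    for outer in fields(cls):
        # get_origin(List[X]) and get_origin(list[X]) both return the builtin list
        if get_origin(outer.type) not in (list, set, tuple):
            continue
        args = get_args(outer.type)
        if not args or not is_dataclass(args[0]):
            continue
        for inner in fields(args[0]):
            # default arguments capture the current loop values (no late binding)
            def getter(self, outer_name=outer.name, inner_name=inner.name):
                return [getattr(item, inner_name) for item in getattr(self, outer_name)]
            setattr(cls, inner.name, property(getter))
    return cls


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@collect_fields  # applied after @dataclass, so fields(cls) is available
@dataclass
class Section:
    positions: List[Position]


sec = Section([Position('a', 52, 10), Position('b', 46, -10)])
print(sec.name)  # ['a', 'b']
print(sec.lat)   # [10, -10]
```

Keeping the standard decorator untouched avoids re-declaring its signature, which has gained parameters (such as `kw_only` and `slots`) in newer Python versions.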
Upvotes: 1
Reputation: 18106
You could create a new field after __init__ was called:
from dataclasses import dataclass, field, fields
from typing import List
@dataclass
class Position:
    name: str
    lon: float
    lat: float
@dataclass
class Section:
    positions: List[Position]
    _pos: dict = field(init=False, repr=False)
    def __post_init__(self):
        # create _pos after init is done, read only!
        Section._pos = property(Section._get_positions)
    def _get_positions(self):
        _pos = {}
        # iterate over all fields and add to _pos
        for field in [f.name for f in fields(self.positions[0])]:
            if field not in _pos:
                _pos[field] = []
            for p in self.positions:
                _pos[field].append(getattr(p, field))
        return _pos
pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)
sec = Section([pos1, pos2, pos3])
print(sec.positions)
print(sec._pos['name'])
print(sec._pos['lon'])
print(sec._pos['lat'])
Out:
[Position(name='a', lon=52, lat=10), Position(name='b', lon=46, lat=-10), Position(name='c', lon=45, lat=-10)]
['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
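A variation on the same idea, assuming the goal is just a read-only dict of lists: `_pos` can be a plain `@property` defined in the class body. Because it carries no class-level annotation, `@dataclass` does not treat it as a field, so neither `field(init=False)` nor reassigning the property in `__post_init__` is needed. The empty-list guard is an addition not present in the answer above.

```python
from dataclasses import dataclass, fields
from typing import Dict, List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

    # plain read-only property: not a dataclass field, no __post_init__ needed
    @property
    def _pos(self) -> Dict[str, list]:
        if not self.positions:
            return {}
        # one list per field of the first Position (all items assumed to be Positions)
        return {
            f.name: [getattr(p, f.name) for p in self.positions]
            for f in fields(self.positions[0])
        }


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec._pos['name'])  # ['a', 'b', 'c']
print(sec._pos['lat'])   # [10, -10, -10]
```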
Edit:
If you need something more generic, you could override __getattr__:
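from dataclasses import dataclass, fields
from typing import List
@dataclass
class Position:
    name: str
    lon: float
    lat: float
@dataclass
class Section:
    positions: List[Position]
    def __getattr__(self, keyName):
        # only called when normal attribute lookup fails, e.g. for "names"
        if keyName == "positions":
            # avoid infinite recursion when positions is not set yet (e.g. in copy.copy)
            raise AttributeError(keyName)
        for f in fields(self.positions[0]):
            if f"{f.name}s" == keyName:
                return [getattr(x, f.name) for x in self.positions]
        # unknown attribute: raise instead of silently returning None
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {keyName!r}")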
pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)
sec = Section([pos1, pos2, pos3])
print(sec.names)
print(sec.lons)
print(sec.lats)
Out:
['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
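One side effect of resolving names in `__getattr__` is that `dir(sec)` (and therefore tab completion) does not list `names`, `lons` or `lats`. A sketch of how `__dir__` could advertise them, assuming all items in `positions` share the first item's fields; the helper `_plural_names` is a hypothetical name.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

    def _plural_names(self):
        # hypothetical helper mapping "names" -> "name", "lons" -> "lon", ...
        if not self.positions:
            return {}
        return {f"{f.name}s": f.name for f in fields(self.positions[0])}

    def __getattr__(self, key):
        # only reached when normal attribute lookup has already failed
        if key != "positions":
            plural = self._plural_names()
            if key in plural:
                return [getattr(p, plural[key]) for p in self.positions]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {key!r}")

    def __dir__(self):
        # expose the dynamic names so dir() and tab completion can see them
        return sorted(set(super().__dir__()) | set(self._plural_names()))


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec.names)           # ['a', 'b', 'c']
print('lons' in dir(sec))  # True
```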
Upvotes: 6
Reputation: 704
After some more thinking, I came up with an alternative solution using methods:
from dataclasses import dataclass
from typing import List
@dataclass
class Position:
    name: str
    lon: float
    lat: float
@dataclass
class Section:
    positions: List[Position]
    def names(self):
        return [x.name for x in self.positions]
    def lons(self):
        return [x.lon for x in self.positions]
    def lats(self):
        return [x.lat for x in self.positions]
pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)
sec = Section([pos1, pos2 , pos3])
print(sec.positions)
print(sec.names())
print(sec.lons())
print(sec.lats())
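The three hand-written methods can also be generated in a loop over `fields(Position)` after the class is defined. A minimal sketch, assuming the plural is always the field name plus "s"; the factory `_make_collector` is a hypothetical helper whose purpose is to give each generated method its own copy of the attribute name.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]


def _make_collector(attr):
    # factory so each generated method captures its own attribute name
    def collector(self):
        return [getattr(p, attr) for p in self.positions]
    collector.__name__ = f"{attr}s"
    return collector


# generate Section.names(), Section.lons(), Section.lats() from Position's fields
for f in fields(Position):
    setattr(Section, f"{f.name}s", _make_collector(f.name))


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec.names())  # ['a', 'b', 'c']
print(sec.lons())   # [52, 46, 45]
print(sec.lats())   # [10, -10, -10]
```

Adding a field to Position then adds the matching Section method automatically, at the cost of the methods being invisible to static type checkers.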
Upvotes: 1