Jean-Francois T.

How to create Python dataclasses from YAML with default_factory?

I'm trying to create dataclasses directly from a YAML description, with some attributes having default values.

Based on the PyYAML documentation, I know that I can register tags so that the corresponding classes are created automatically when loading YAML.

Combined with dataclasses, this gives something like:

from dataclasses import dataclass, field

import yaml

@dataclass
class Person(yaml.YAMLObject):
    """A Person, to be loaded from YAML"""

    yaml_loader = yaml.SafeLoader
    yaml_tag = "!person"

    name: str
    parents: list[str]
    age: int = 0

We can test this with the following two YAML files:

!person
name: John Toto
parents:
  - Jacques
  - Lily

and

!person
name: Jacques Toto
parents:
  - Olaf
  - Mary
age: 100

Loading them gives, respectively, the following (note that the default value of age works):

Person(name='John Toto', parents=['Jacques', 'Lily'], age=0)
Person(name='Jacques Toto', parents=['Olaf', 'Mary'], age=100)

However, how do I make this work with lists and dictionaries (or any other attribute that requires field(default_factory=...))?

For example, the following piece of code:

@dataclass
class Person2(yaml.YAMLObject):
    """A Person, to be loaded from YAML"""

    yaml_loader = yaml.SafeLoader
    yaml_tag = "!person2"

    name: str
    parents: list[str] = field(default_factory=list)
    age: int = 0

p2 = yaml.safe_load(
    """
!person2
name: Jacques Toto
age: 100
"""
)
print(p2)

raises the following exception when executing print(p2):

File c:\Python311\Lib\dataclasses.py:240, in _recursive_repr.<locals>.wrapper(self)
    238 repr_running.add(key)
    239 try:
--> 240     result = user_function(self)
    241 finally:
    242     repr_running.discard(key)

File <string>:3, in __repr__(self)

AttributeError: 'Person2' object has no attribute 'parents'
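The error can be reproduced without PyYAML. This sketch assumes, as PyYAML's construct_yaml_object does for classes without __setstate__, that the loader allocates the object with cls.__new__ (so __init__ never runs) and then updates its __dict__ with the YAML mapping. Person2 here drops yaml.YAMLObject only to stay stdlib-only. It shows why age works but parents does not: @dataclass keeps a plain default as a class attribute, while a default_factory field gets no class attribute at all.

```python
from dataclasses import dataclass, field


@dataclass
class Person2:
    """Same fields as in the question, without the YAML plumbing."""

    name: str
    parents: list[str] = field(default_factory=list)
    age: int = 0


# Mimic the loader: allocate the object without calling __init__,
# then copy the YAML values straight into the instance __dict__.
p2 = Person2.__new__(Person2)
p2.__dict__.update({"name": "Jacques Toto", "age": 100})

# A plain default stays on the class, so attribute lookup can fall back to it...
print("age" in vars(Person2), Person2.age)  # True 0
# ...but a default_factory field has no class attribute to fall back on.
print("parents" in vars(Person2))  # False
print(hasattr(p2, "parents"))  # False

try:
    repr(p2)  # the generated __repr__ reads self.parents
except AttributeError as exc:
    print("repr failed:", exc)
```

This is the same situation as in the traceback: the instance simply never received a parents attribute.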

How can I populate parents with a default value using default_factory?

Answers (1)

Jean-Francois T.

By stepping through the PyYAML code that creates the instance, we find these lines in the library:

C:\Python311\Lib\site-packages\yaml\constructor.py

    def construct_yaml_object(self, node, cls):
        data = cls.__new__(cls)
        yield data
        if hasattr(data, '__setstate__'):
            state = self.construct_mapping(node, deep=True)
            data.__setstate__(state)
        else:
            state = self.construct_mapping(node)
            data.__dict__.update(state)

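So under normal circumstances, PyYAML creates the object with cls.__new__(cls), which does not call __init__, and then simply updates the instance's __dict__ (i.e. its attributes) with the values found in the YAML. Since __init__ never runs, the default_factory defaults are never applied.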

However, if the instance has a __setstate__ method, PyYAML calls it with the mapping built from the YAML node instead of updating __dict__ directly. This gives us a hook to fill in the missing defaults.
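To make the two branches concrete, here is a simplified, stdlib-only mimic of the library snippet above. construct_like_pyyaml, Plain and WithHook are hypothetical names; the real method is a generator that first yields the empty object and builds the mapping from a YAML node, details that are omitted here.

```python
from typing import Any


def construct_like_pyyaml(cls: type, state: dict[str, Any]) -> Any:
    """Hypothetical, simplified stand-in for construct_yaml_object."""
    data = cls.__new__(cls)  # __init__ is NOT called
    if hasattr(data, "__setstate__"):
        data.__setstate__(state)  # the class decides how to apply the state
    else:
        data.__dict__.update(state)  # raw copy, no defaults applied
    return data


class Plain:
    pass


class WithHook:
    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        self.hook_called = True  # proves the hook ran


plain = construct_like_pyyaml(Plain, {"x": 1})
hooked = construct_like_pyyaml(WithHook, {"x": 1})
print(vars(plain))   # {'x': 1}
print(vars(hooked))  # {'x': 1, 'hook_called': True}
```

Defining __setstate__ is therefore the only place where the class gets a say in how the loaded values are applied.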

Therefore, you could modify your class as follows:

from dataclasses import dataclass, field
from typing import Any

import yaml

@dataclass
class Person(yaml.YAMLObject):
    """A Person, to be loaded from YAML"""

    yaml_loader = yaml.SafeLoader
    yaml_tag = "!person"

    name: str
    parents: list[str] = field(default_factory=list)
    age: int = 0

    def __setstate__(self, data: dict[str, Any]):
        """Set default values using default_factory when loaded from YAML"""
        self.__dict__.update(data)

        # pylint: disable=no-member
        for attr, dataclass_field in self.__dataclass_fields__.items():
            default_factory = dataclass_field.default_factory
            if not hasattr(self, attr) and callable(default_factory):
                setattr(self, attr, default_factory())

After updating __dict__, it goes through every field defined in the dataclass: if the attribute is not set and the field has a default_factory, it calls that factory to create the attribute.

It relies on some internals (the __dataclass_fields__ attribute), so it might not be the cleanest way to do it, but at least it works.
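If you would rather avoid __dataclass_fields__, the same loop can be written with the public dataclasses.fields() function, which returns the regular fields of a dataclass or instance (it leaves out ClassVar and InitVar pseudo-fields). The sketch below packages this as a hypothetical mixin, DefaultFactoryState, and compares default_factory against dataclasses.MISSING instead of using callable(). It is checked by calling __new__ and then __setstate__, mimicking what the loader does.

```python
from dataclasses import MISSING, dataclass, field, fields
from typing import Any


class DefaultFactoryState:
    """Hypothetical mixin: apply default_factory values in __setstate__."""

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        for f in fields(self):  # public API, regular fields only
            # Only fill fields the YAML did not provide and that have a factory.
            if f.name not in self.__dict__ and f.default_factory is not MISSING:
                setattr(self, f.name, f.default_factory())


@dataclass
class Person(DefaultFactoryState):
    name: str
    parents: list[str] = field(default_factory=list)
    age: int = 0


# Mimic the loader: allocate without __init__, then hand over the state.
p = Person.__new__(Person)
p.__setstate__({"name": "Jacques Toto", "age": 100})
print(p)  # Person(name='Jacques Toto', parents=[], age=100)
```

With PyYAML, the mixin would presumably be combined with the tag base class, e.g. class Person(DefaultFactoryState, yaml.YAMLObject), keeping yaml_loader and yaml_tag as in the question; that combination is an assumption and is not exercised in this sketch. Each loaded instance gets its own fresh list, exactly as with a normal __init__ call.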

BONUS

If you also want __post_init__ to be called after the instance is created from YAML, you can call it from the same method, with something like:

    def __setstate__(self, data: dict[str, Any]):
        """Set default values using default_factory when loaded from YAML
           and call `__post_init__` if this is defined
        """
        self.__dict__.update(data)

        # pylint: disable=no-member
        for attr, dataclass_field in self.__dataclass_fields__.items():
            default_factory = dataclass_field.default_factory
            if not hasattr(self, attr) and callable(default_factory):
                setattr(self, attr, default_factory())

        # Just for more convenience, we can also allow some __post_init__,
        # for example, it is useful in sub-classes
        post_init = getattr(self, "__post_init__", None)
        if callable(post_init):
            post_init()  # pylint: disable=not-callable
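A quick check of the BONUS version, again without PyYAML: the __setstate__ body below is the one from the answer, and the loader is mimicked by calling __new__ followed by __setstate__. The __post_init__ that collapses whitespace in name is only an illustrative assumption.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Person:
    name: str
    parents: list[str] = field(default_factory=list)
    age: int = 0

    def __post_init__(self) -> None:
        # Illustrative invariant: collapse repeated whitespace in the name.
        self.name = " ".join(self.name.split())

    def __setstate__(self, data: dict[str, Any]) -> None:
        # Same body as the BONUS version above.
        self.__dict__.update(data)
        for attr, dataclass_field in self.__dataclass_fields__.items():
            default_factory = dataclass_field.default_factory
            if not hasattr(self, attr) and callable(default_factory):
                setattr(self, attr, default_factory())
        post_init = getattr(self, "__post_init__", None)
        if callable(post_init):
            post_init()


# Mimic the loader: allocate without __init__, then hand over the state.
loaded = Person.__new__(Person)
loaded.__setstate__({"name": "  Jacques   Toto ", "age": 100})
print(loaded)  # Person(name='Jacques Toto', parents=[], age=100)
print(loaded == Person("Jacques Toto", age=100))  # True
```

The loaded object ends up equal to one built through the normal constructor, which is the point of running both the factories and __post_init__ from __setstate__.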
