Markus

Reputation: 175

How to overwrite Python Dataclass 'asdict' method

I have a dataclass, which looks like this:

@dataclass
class myClass:
   id: str
   mode: str
   value: float

Calling asdict on an instance results in:

dataclasses.asdict(myClass(id, mode, value))
{"id": id, "mode": mode, "value": value}

But what I want is

{id:{"mode": mode, "value": value}}

I thought I could achieve this by adding a to_dict method to my dataclass that returns the desired dict, but asdict does not use it, so that didn't work.

How could I get my desired result?

Upvotes: 15

Views: 8513

Answers (2)

OrderFromChaos

Reputation: 170

Here is a solution for nested dataclasses.

The relevant underlying implementation is:

# from lib/python3.11/dataclasses.py:1280
def _asdict_inner(obj, dict_factory):
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _asdict_inner(getattr(obj, f.name), dict_factory)
            result.append((f.name, value))
        return dict_factory(result)
    # ... (other types below here)

As you can see, the dataclass branch is driven by fields(). fields() reads the attribute named by the module constant _FIELDS, which is "__dataclass_fields__". Unfortunately, this attribute is not easy to override, since many other dataclass functions rely on it as the "ground truth" for a class's fields.
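To make the relationship between fields() and __dataclass_fields__ concrete, here is a minimal sketch. The Point class is a hypothetical example introduced only for illustration; it is not part of the solution.

```python
from dataclasses import dataclass, fields


@dataclass
class Point:  # hypothetical example class, not from the answer
    x: int
    y: int = 0


p = Point(1)

# fields() returns a tuple of Field objects, in definition order
names_from_fields = [f.name for f in fields(p)]

# __dataclass_fields__ is a dict mapping field name -> Field object
names_from_attr = list(p.__dataclass_fields__)

print(names_from_fields)  # ['x', 'y']
print(names_from_attr)    # ['x', 'y']
```

For a plain dataclass like this one, both views list the same names, which is why changing one without the other tends to break other dataclass helpers.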

When an implementation is tightly coupled like this, the easiest fix is to monkey-patch the library's internal function with your desired behavior. Here is how I solved this, by adding a hook method __dict_factory_override__ (which, in this example, filters out float fields that are close to 0):

src/classes.py

import math
from dataclasses import dataclass

@dataclass
class TwoLayersDeep:
    f: int = 0

@dataclass
class OneLayerDeep:
    c: TwoLayersDeep
    d: float = 1.0
    e: float = 0.0

    def __dict_factory_override__(self):
        normal_dict = {k: getattr(self, k) for k in self.__dataclass_fields__}
        return {
            k: v for k, v in normal_dict.items()
            if not isinstance(v, float) or not math.isclose(v, 0)
            # ^^ here I specify the rule for which items are included
            # for printing/formatting
        }

@dataclass
class TopLevelClass:
    a: OneLayerDeep
    b: int = 0

src/dataclasses_override.py

import dataclasses

# If dataclass has __dict_factory_override__, use that instead of dict_factory
_asdict_inner_actual = dataclasses._asdict_inner
def _asdict_inner(obj, dict_factory):

    # if override exists, intercept and return that instead
    if dataclasses._is_dataclass_instance(obj):
        if getattr(obj, '__dict_factory_override__', None):
            user_dict = obj.__dict_factory_override__()

            for k, v in user_dict.items(): # in case of further nesting
                if dataclasses._is_dataclass_instance(v):
                    user_dict[k] = _asdict_inner(v, dict_factory)
            return user_dict

    # otherwise do original behavior
    return _asdict_inner_actual(obj, dict_factory)
dataclasses._asdict_inner = _asdict_inner
asdict = dataclasses.asdict

main.py

from src.classes import OneLayerDeep, TopLevelClass, TwoLayersDeep
from src.dataclasses_override import asdict

print(asdict(
    TopLevelClass(
        a=OneLayerDeep(
            c=TwoLayersDeep()
        )
    )
))
# {'a': {'c': {'f': 0}, 'd': 1.0}, 'b': 0}
# As expected, e=0.0 is omitted because math.isclose(e, 0) is true

Upvotes: 0

Evgeniy_Burdin

Reputation: 703

from dataclasses import dataclass, asdict


@dataclass
class myClass:
    id: str
    mode: str
    value: float


def my_dict(data):
    return {
        data[0][1]: {
            field: value for field, value in data[1:]
        }
    }


instance = myClass("123", "read", 1.23)

data = {"123": {"mode": "read", "value": 1.23}}

assert asdict(instance, dict_factory=my_dict) == data
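
The indexing in my_dict works because, when a custom dict_factory is given, asdict passes it a list of (field_name, value) pairs in field-definition order, so data[0] is the ("id", ...) pair. A minimal sketch to inspect this follows; recording_factory and seen are hypothetical names introduced here for illustration.

```python
from dataclasses import dataclass, asdict


@dataclass
class myClass:
    id: str
    mode: str
    value: float


seen = []  # hypothetical helper list: collects what the factory receives


def recording_factory(pairs):
    # Record a copy of the incoming pairs, then build an ordinary dict
    seen.append(list(pairs))
    return dict(pairs)


result = asdict(myClass("123", "read", 1.23), dict_factory=recording_factory)

print(seen[0])  # [('id', '123'), ('mode', 'read'), ('value', 1.23)]
print(result)   # {'id': '123', 'mode': 'read', 'value': 1.23}
```

Note that the factory is also called for every nested dataclass instance, so a factory like my_dict, which assumes the first field is the key, is best suited to flat classes like this one.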

Upvotes: 7
