user10203644

Reputation: 41

Custom type serialization inside a data class

I recently discovered Python dataclasses and am trying to implement a dataclass that has a third-party library type as a field. This type has its own serialisation methods, and I expected to be able to define the special handling in pre_load/post_load and pre_dump/post_dump hooks. I managed to implement serialisation, but I cannot figure out how to deserialise from a dict:

from marshmallow_dataclass import dataclass
from marshmallow import pre_load, post_load, post_dump

# third party class has the following methods
class Message:
    def __init__(self, text: str):
        # the actual implementation of the object is much more complex
        self.text = text

    def to_json(self) -> dict:
        return {'text': self.text}

    @classmethod
    def from_json(cls, json: dict):
        return cls(json['text'])


@dataclass
class MyClass:
    message: Message

    @pre_load
    def preload(self, data, **kwargs):
        data['message'] = Message.from_json(data['message'])
        return data

    @post_load(pass_original=True)
    def postload(self, data, obj, **kwargs):
        # to leave data unchanged after construction
        data['message'] = obj.message.to_json()

    @post_dump(pass_original=True)
    def postdump(self, data, obj, **kwargs):
        data['message'] = obj.message.to_json()
        return data


msg = Message('Some text')

# serialization works fine
my_class = MyClass(msg)
json_dict = MyClass.Schema().dump(my_class)  # = {'message': {'text': 'Some text'}}

# throws: marshmallow.exceptions.ValidationError: {'message': {'_schema': ['Invalid input type.']}}
my_class = MyClass.Schema().load(json_dict)

Upvotes: 2

Views: 1306

Answers (1)

user10203644

Reputation: 41

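It turns out there is an elegant solution: define a custom marshmallow field that delegates to the type's own to_json/from_json methods, and attach it to the type with marshmallow_dataclass's NewType: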

from marshmallow_dataclass import dataclass, NewType
from marshmallow import fields


class MessageField(fields.Field):
    def _serialize(self, value: Message, attr, obj, **kwargs):
        return value.to_json()

    def _deserialize(self, value: dict, attr, data, **kwargs):
        return Message.from_json(value)


# Message is the third-party class from the question; MessageField handles its (de)serialisation
MessageType = NewType('Message', Message, MessageField)

@dataclass
class MyClass:
    message: MessageType

Upvotes: 1
