Baron Yugovich

Reputation: 4307

Serialize and deserialize objects from user-defined classes

Suppose I have a class hierarchy like this:

class SerializableWidget(object):
    pass  # some code

class WidgetA(SerializableWidget):
    pass  # some code

class WidgetB(SerializableWidget):
    pass  # some code

I want to be able to serialize instances of WidgetA and WidgetB (and potentially other widgets) to text files as JSON. Then I want to be able to deserialize them without knowing their specific class beforehand:

some_widget = deserialize_from_file(file_path)  # pseudocode, doesn't have to be exactly a function like this

and some_widget needs to be an instance of the exact subclass of SerializableWidget that was saved. How do I do this? Which methods exactly do I need to override or implement in each class of my hierarchy?

Assume all fields of the above classes are primitive types. Do I override something like __to_json__ and __from_json__ methods?

Upvotes: 3

Views: 2691

Answers (2)

takanuva15

Reputation: 1678

I really liked @nosklo's answer, but I wanted to customize the value that gets saved to identify each model type, so I extended his code a little with a second decorator for registering subtypes.

(I know this isn't directly related to the question, but you can use this to serialize to JSON too, since it produces dict objects. Note that your base class must use the @dataclass decorator for serialization to work, because serialize() relies on dataclasses.asdict; otherwise you could adapt this code to call an _as_dict method like @nosklo's answer does.)
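To illustrate the @dataclass requirement: dataclasses.asdict only accepts dataclass instances, and the dict it returns can go straight into json.dumps. A minimal sketch; Widget and PlainWidget are hypothetical example classes, not part of the answer's code:

```python
import dataclasses
import json
from dataclasses import dataclass


@dataclass
class Widget:  # hypothetical example dataclass
    name: str
    size: int


class PlainWidget:  # hypothetical: same fields, but not a dataclass
    def __init__(self, name, size):
        self.name = name
        self.size = size


w = Widget('a', 3)
as_json = json.dumps(dataclasses.asdict(w))  # dataclass -> plain dict -> JSON text
restored = Widget(**json.loads(as_json))     # rebuild through the generated __init__

error_message = None
try:
    dataclasses.asdict(PlainWidget('a', 3))
except TypeError as exc:
    # asdict() rejects objects that are not dataclass instances
    error_message = str(exc)
```

This is why the polymorphic decorator below only works on a dataclass base: its serialize() goes through asdict.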

data.csv:

model_type, prop1
sub1, testfor1
sub2, testfor2

test.py:

import csv
from abc import ABC
from dataclasses import dataclass

from polymorphic import polymorphic


@polymorphic(keyname="model_type")
@dataclass
class BaseModel(ABC):
    prop1: str


@polymorphic.subtype_when_(keyval="sub1")
class SubModel1(BaseModel):
    pass


@polymorphic.subtype_when_(keyval="sub2")
class SubModel2(BaseModel):
    pass


with open('data.csv') as csvfile:
    reader = csv.DictReader(csvfile, skipinitialspace=True)
    for row_data_dict in reader:
        price_req = BaseModel.deserialize(row_data_dict)
        print(price_req, '\n\tre-serialized: ', price_req.serialize())

polymorphic.py:

import dataclasses
from abc import ABC
from typing import Type


# https://stackoverflow.com/a/51976115
class _Polymorphic:
    def __init__(self, keyname='__class__'):
        self._key = keyname
        self._class_mapping = {}  # keyval -> registered subclass

    def __call__(self, abc: Type[ABC]):
        # Attach the registry operations to the decorated base class
        setattr(abc, '_register_subtype', self._register_subtype)
        setattr(abc, 'serialize', lambda self_subclass: self.serialize(self_subclass))
        setattr(abc, 'deserialize', self.deserialize)
        return abc

    def _register_subtype(self, keyval, cls):
        self._class_mapping[keyval] = cls

    def serialize(self, self_subclass) -> dict:
        my_dict = dataclasses.asdict(self_subclass)  # requires a dataclass instance
        my_dict[self._key] = next(keyval for keyval, cls in self._class_mapping.items() if cls is type(self_subclass))
        return my_dict

    def deserialize(self, data: dict):
        data = dict(data)  # copy so the caller's dict is not mutated
        cls = self._class_mapping.get(data.pop(self._key, None))
        if cls is None:
            raise ValueError(f'Invalid data: {self._key} was not found or it referred to an unrecognized class')
        return cls(**data)

    @staticmethod
    def subtype_when_(*, keyval: str):
        def register_subtype_for(_cls: type):
            # Fall back to the class name if an empty keyval was given
            _cls._register_subtype(keyval or _cls.__name__, _cls)
            # Return the class itself so isinstance() checks keep working
            return _cls

        return register_subtype_for


polymorphic = _Polymorphic

Sample console output:

SubModel1(prop1='testfor1') 
    re-serialized:  {'prop1': 'testfor1', 'model_type': 'sub1'}
SubModel2(prop1='testfor2') 
    re-serialized:  {'prop1': 'testfor2', 'model_type': 'sub2'}
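The same customizable key values can also be had without a separate decorator object, by passing the key value as a class keyword argument that __init_subclass__ picks up. This is a swapped-in technique, not the code above; a minimal sketch that reuses the example's names, where _subtypes, _keyname and _keyval are hypothetical attributes:

```python
import dataclasses
import json
from dataclasses import dataclass


@dataclass
class BaseModel:
    prop1: str

    _subtypes = {}  # keyval -> subclass; unannotated, so not a dataclass field
    _keyname = 'model_type'

    def __init_subclass__(cls, keyval=None, **kwargs):
        # keyval comes from the class statement, e.g. class X(BaseModel, keyval='x')
        super().__init_subclass__(**kwargs)
        cls._keyval = keyval or cls.__name__
        BaseModel._subtypes[cls._keyval] = cls

    def serialize(self) -> dict:
        d = dataclasses.asdict(self)
        d[self._keyname] = self._keyval
        return d

    @classmethod
    def deserialize(cls, data: dict):
        data = dict(data)  # don't mutate the caller's dict
        return BaseModel._subtypes[data.pop(cls._keyname)](**data)


class SubModel1(BaseModel, keyval='sub1'):
    pass


class SubModel2(BaseModel, keyval='sub2'):
    pass


text = json.dumps(SubModel1('testfor1').serialize())
restored = BaseModel.deserialize(json.loads(text))
```

Here the subclasses stay ordinary classes and registration happens as a side effect of defining them.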

Upvotes: 0

nosklo

Reputation: 222842

You can solve this with many methods. One example is to use the object_hook and default parameters to json.load and json.dump respectively.
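To see the two hooks in isolation: the encoder calls default for any object it can't serialize natively, and the decoder calls object_hook with every JSON object it decodes (as a dict). A minimal sketch with a hypothetical Point class and a hypothetical "__point__" marker key:

```python
import json


class Point:  # hypothetical example class
    def __init__(self, x, y):
        self.x = x
        self.y = y


def encode(obj):
    # json.dumps calls this only for objects it can't serialize itself
    if isinstance(obj, Point):
        return {'__point__': True, 'x': obj.x, 'y': obj.y}
    raise TypeError(f'{type(obj).__name__} is not JSON serializable')


def decode(d):
    # json.loads calls this for every JSON object it decodes
    if d.pop('__point__', False):
        return Point(d['x'], d['y'])
    return d


text = json.dumps({'origin': Point(0, 0)}, default=encode)
data = json.loads(text, object_hook=decode)
```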

All you need is to store the class name together with the serialized object; then, when loading, you use a mapping from each name back to its class.

The example below uses a dispatcher class decorator that records each decorated class by name, adds the class name when serializing, and looks the class up again when deserializing. All you need is an _as_dict method on each class that converts its data to a dict:
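If decorating every class is inconvenient, the name-to-class mapping can also be filled in automatically with __init_subclass__ (Python 3.6+). This is a swapped-in technique, not the decorator used in this answer. A minimal sketch using the question's class names; to_dict, from_dict, serialize_to_file, deserialize_from_file and the "__class__" key are hypothetical names, and it assumes each __init__ takes parameters named like the attributes it sets:

```python
import json
import os
import tempfile


class SerializableWidget:
    _registry = {}  # class name -> class, shared by the whole hierarchy

    def __init_subclass__(cls, **kwargs):
        # Runs once for every subclass definition, registering it by name
        super().__init_subclass__(**kwargs)
        SerializableWidget._registry[cls.__name__] = cls

    def to_dict(self):
        # Assumes all state lives in primitive-valued instance attributes
        return {'__class__': type(self).__name__, **vars(self)}

    @staticmethod
    def from_dict(d):
        d = dict(d)  # don't mutate the caller's dict
        subclass = SerializableWidget._registry[d.pop('__class__')]
        return subclass(**d)  # assumes __init__ parameters match attribute names


class WidgetA(SerializableWidget):
    def __init__(self, color):
        self.color = color


class WidgetB(SerializableWidget):
    def __init__(self, width, height):
        self.width = width
        self.height = height


def serialize_to_file(widget, file_path):
    with open(file_path, 'w') as f:
        json.dump(widget.to_dict(), f)


def deserialize_from_file(file_path):
    with open(file_path) as f:
        return SerializableWidget.from_dict(json.load(f))


path = os.path.join(tempfile.mkdtemp(), 'widget.json')
serialize_to_file(WidgetB(2, 3), path)
some_widget = deserialize_from_file(path)
```

A new widget class then only needs to be defined (and imported before loading) to become deserializable; an unknown name in the file raises KeyError.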

import json

@dispatcher
class Parent(object):
    def __init__(self, name):
        self.name = name

    def _as_dict(self):
        return {'name': self.name}


@dispatcher
class Child1(Parent):
    def __init__(self, name, n=0):
        super().__init__(name)
        self.n = n

    def _as_dict(self):
        d = super()._as_dict()
        d['n'] = self.n
        return d

@dispatcher
class Child2(Parent):
    def __init__(self, name, k='ok'):
        super().__init__(name)
        self.k = k

    def _as_dict(self):
        d = super()._as_dict()
        d['k'] = self.k
        return d

Now for the tests. First, let's create a list with 3 objects of different types.

>>> obj = [Parent('foo'), Child1('bar', 15), Child2('baz', 'works')]

Serializing it will yield the data with the class name in each object:

>>> s = json.dumps(obj, default=dispatcher.encoder_default)
>>> print(s)
[{"name": "foo", "__class__": "Parent"}, {"name": "bar", "n": 15, "__class__": "Child1"}, {"name": "baz", "k": "works", "__class__": "Child2"}]

And loading it back generates the correct objects:

>>> obj2 = json.loads(s, object_hook=dispatcher.decoder_hook)
>>> print(obj2)
[<__main__.Parent object at 0x7fb6cd561cf8>, <__main__.Child1 object at 0x7fb6cd561d68>, <__main__.Child2 object at 0x7fb6cd561e10>]
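Because json calls default again on whatever default returns, and calls object_hook on the innermost objects first, the same approach also handles objects that contain other registered objects. A self-contained sketch with a compact registry in the same spirit as _Dispatcher; register, Leaf and Group are hypothetical names:

```python
import json

_classes = {}  # class name -> class


def register(cls):
    _classes[cls.__name__] = cls
    return cls


def encoder_default(obj):
    d = obj._as_dict()
    d['__class__'] = type(obj).__name__
    return d


def decoder_hook(d):
    classname = d.pop('__class__', None)
    return _classes[classname](**d) if classname else d


@register
class Leaf:
    def __init__(self, name):
        self.name = name

    def _as_dict(self):
        return {'name': self.name}


@register
class Group:
    def __init__(self, members):
        self.members = members  # a list of registered objects

    def _as_dict(self):
        # Members stay objects here; json calls encoder_default on each of them
        return {'members': self.members}


s = json.dumps(Group([Leaf('a'), Leaf('b')]), default=encoder_default)
g = json.loads(s, object_hook=decoder_hook)
```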

Finally, here's the implementation of dispatcher (it has to be defined before the classes that use it as a decorator):

class _Dispatcher:
    def __init__(self, classname_key='__class__'):
        self._key = classname_key
        self._classes = {}  # class name -> class, filled in by the decorator

    def __call__(self, class_):  # used as a class decorator
        self._classes[class_.__name__] = class_
        return class_

    def decoder_hook(self, d):
        # Called by json.loads for every decoded JSON object
        classname = d.pop(self._key, None)
        if classname:
            return self._classes[classname](**d)
        return d  # plain dicts without the key are left as-is

    def encoder_default(self, obj):
        # Called by json.dumps for objects it cannot serialize natively
        d = obj._as_dict()
        d[self._key] = type(obj).__name__
        return d


dispatcher = _Dispatcher()
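One design point worth noting: decoder_hook only instantiates classes that were registered with the decorator, so a JSON file can't name an arbitrary class to construct. An unregistered name currently surfaces as a bare KeyError. A possible stricter variant (a sketch, not part of the original answer; StrictDispatcher is a hypothetical name) raises a clearer error and leaves the input dict untouched:

```python
import json


class StrictDispatcher:
    def __init__(self, classname_key='__class__'):
        self._key = classname_key
        self._classes = {}  # allow-list: only these classes can be rebuilt

    def __call__(self, class_):
        self._classes[class_.__name__] = class_
        return class_

    def decoder_hook(self, d):
        if self._key not in d:
            return d  # ordinary JSON object, keep it as a dict
        fields = {k: v for k, v in d.items() if k != self._key}
        try:
            class_ = self._classes[d[self._key]]
        except KeyError:
            raise ValueError(f'unregistered class name: {d[self._key]!r}') from None
        return class_(**fields)

    def encoder_default(self, obj):
        d = obj._as_dict()
        d[self._key] = type(obj).__name__
        return d


dispatcher = StrictDispatcher()


@dispatcher
class Parent:
    def __init__(self, name):
        self.name = name

    def _as_dict(self):
        return {'name': self.name}


ok = json.loads('{"__class__": "Parent", "name": "foo"}', object_hook=dispatcher.decoder_hook)

error_message = None
try:
    json.loads('{"__class__": "Popen", "args": "ls"}', object_hook=dispatcher.decoder_hook)
except ValueError as exc:
    error_message = str(exc)
```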

Upvotes: 4
