twhughes

Reputation: 534

Distinguishing between Pydantic Models with same fields

I'm using Pydantic to define hierarchical data in which there are models with identical attributes.

However, when I save and load these models, Pydantic can no longer distinguish which model was used and picks the first one in the field type annotation.

I understand that this is expected behavior based on the documentation. However, the class type information is important to my application.

What is the recommended way to distinguish between different classes in Pydantic? One hack is to simply add an extraneous field to one of the models, but I'd like to find a more elegant solution.

See the simplified example below: container is initialized with data of type DataB, but after exporting and loading, the new container has data of type DataA as it's the first element in the type declaration of container.data.

Thanks for your help!

from abc import ABC
from pydantic import BaseModel  # pydantic 1.8.2
from typing import Union

class Data(BaseModel, ABC):
    """ base class for a Member """
    number: float

class DataA(Data):
    """ A type of Data"""
    pass

class DataB(Data):
    """ Another type of Data """
    pass

class Container(BaseModel):
    """ container holds a subclass of Data """
    data: Union[DataA, DataB]

# initialize container with DataB
data = DataB(number=1.0)
container = Container(data=data)

# export container to string and load new container from string
string = container.json()
new_container = Container.parse_raw(string)

# look at type of container.data
print(type(new_container.data).__name__)
# >>> DataA

Upvotes: 11

Views: 13755

Answers (3)

Wizard.Ritvik

Reputation: 11612

I wanted to take the opportunity to list another possible alternative to pydantic here, even though pydantic already supports this use case very well, as the other answer shows.

I am the creator and maintainer of a relatively new and lesser-known JSON serialization library, the Dataclass Wizard, which relies on the Python dataclasses module to perform its magic. As of the latest version, 0.14.0, dataclass-wizard supports dataclasses within Union types. Previously it did not support them at all, which was a glaring omission and something on my "to-do" list to (eventually) add.

The reason it did not work before is that the data being de-serialized is often a JSON object, which only knows simple types such as arrays and dictionaries. A dict would not otherwise match any of the Union[Data1, Data2] types, even if it had all the correct dataclass fields as keys. This is simply because the library doesn't compare the dict against the dataclass fields of each Union member, though that might change in a future release.
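
To make the idea concrete, here is a minimal sketch in plain Python of how a tag can resolve a dict to one member of a Union. It is not the dataclass-wizard implementation: the TAGS registry and the load_union helper are hypothetical names made up for this illustration, and only the standard library is used.

```python
from dataclasses import dataclass
from typing import Union, get_args


@dataclass
class DataA:
    number: float


@dataclass
class DataB:
    number: float


# Hypothetical registry mapping a tag string to a class (an assumption for
# this sketch; a real library would build it from class metadata).
TAGS = {"A": DataA, "B": DataB}
TAG_KEY = "__tag__"


def load_union(payload: dict, union_type):
    """Build the Union member named by the payload's tag."""
    fields = dict(payload)  # copy so the caller's dict is not mutated
    tag = fields.pop(TAG_KEY, None)
    cls = TAGS.get(tag)
    if cls is None or cls not in get_args(union_type):
        raise ValueError(f"unknown or missing tag: {tag!r}")
    return cls(**fields)


loaded = load_union({"number": 2.0, "__tag__": "B"}, Union[DataA, DataB])
print(type(loaded).__name__)
```

Without the tag, the two dicts would have identical keys and nothing would tell the loader which class to build, so the sketch raises an error instead of guessing.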

Here is a simple example that demonstrates dataclasses in Union types, using a class inheritance model with the JSONWizard mixin class:

With Class Inheritance

from abc import ABC
from dataclasses import dataclass
from typing import Union

from dataclass_wizard import JSONWizard


@dataclass
class Data(ABC):
    """ base class for a Member """
    number: float


class DataA(Data, JSONWizard):
    """ A type of Data"""

    class _(JSONWizard.Meta):
        """
        This defines a custom tag that uniquely identifies the dataclass.
        """
        tag = 'A'


class DataB(Data, JSONWizard):
    """ Another type of Data """

    class _(JSONWizard.Meta):
        """
        This defines a custom tag that uniquely identifies the dataclass.
        """
        tag = 'B'


@dataclass
class Container(JSONWizard):
    """ container holds a subclass of Data """
    data: Union[DataA, DataB]

The usage is shown below and is again straightforward. It relies on a special __tag__ key in a dictionary or JSON object to marshal it into the correct dataclass, based on the Meta.tag value set for that class above.

print('== Load with DataA ==')

input_dict = {
    'data': {
        'number': '1.0',
        '__tag__': 'A'
    }
}

# De-serialize the `dict` object to a `Container` instance.
container = Container.from_dict(input_dict)

print(repr(container))
# prints:
#   Container(data=DataA(number=1.0))

# Show the prettified JSON representation of the instance.
print(container)

# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataA

print()

print('== Load with DataB ==')

# initialize container with DataB
data_b = DataB(number=2.0)
container = Container(data=data_b)

print(repr(container))
# prints:
#   Container(data=DataB(number=2.0))

# Show the prettified JSON representation of the instance.
print(container)

# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataB

# Assert we end up with the same instance when serializing and de-serializing
# our data.
string = container.to_json()
assert container == Container.from_json(string)

Without Class Inheritance

Here is the same example as above, relying solely on dataclasses without any special class inheritance model:

from abc import ABC
from dataclasses import dataclass
from typing import Union

from dataclass_wizard import asdict, fromdict, LoadMeta


@dataclass
class Data(ABC):
    """ base class for a Member """
    number: float


class DataA(Data):
    """ A type of Data"""


class DataB(Data):
    """ Another type of Data """


@dataclass
class Container:
    """ container holds a subclass of Data """
    data: Union[DataA, DataB]


# Setup tags for the dataclasses. This can be passed into either
# `LoadMeta` or `DumpMeta`.
#
# Note that I'm not a fan of this syntax either, so it might change. I was
# thinking of something more explicit, like `LoadMeta(...).bind_to(class)`
LoadMeta(DataA, tag='A')
LoadMeta(DataB, tag='B')

# The rest is the same as before.

# initialize container with DataB
data = DataB(number=2.0)
container = Container(data=data)

print(repr(container))
# prints:
#   Container(data=DataB(number=2.0))

# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataB

# Assert we end up with the same data when serializing and de-serializing.
out_dict = asdict(container)
assert container == fromdict(Container, out_dict)

Upvotes: 2

twhughes

Reputation: 534

I'm trying to hack something together in the meantime using custom validators. Basically, the class decorator adds a class_name: str field, which is written to the JSON string. The validator then looks up the correct subclass based on its value.

import json


def register_distinct_subclasses(fields: tuple):
    """ fields is tuple of subclasses that we want to be registered as distinct """

    field_map = {field.__name__: field for field in fields}

    def _register_distinct_subclasses(cls):
        """ cls is the superclass of fields, which we add a new validator to """

        orig_init = cls.__init__

        class _class:
            class_name: str

            def __init__(self, **kwargs):
                class_name = type(self).__name__
                kwargs["class_name"] = class_name
                orig_init(self, **kwargs)

            @classmethod
            def __get_validators__(cls):
                yield cls.validate

            @classmethod
            def validate(cls, v):
                if isinstance(v, dict):
                    class_name = v.get("class_name")
                    json_string = json.dumps(v)
                else:
                    class_name = v.class_name
                    json_string = v.json()
                cls_type = field_map[class_name]
                return cls_type.parse_raw(json_string)

        return _class
    return _register_distinct_subclasses

It is called as follows:

Data = register_distinct_subclasses((DataA, DataB))(Data)

Upvotes: 0

alex_noname

Reputation: 32053

As correctly noted in the comments, without storing additional information models cannot be distinguished when parsing.
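
A small standard-library sketch illustrates the point. It does not use pydantic: the dataclasses and the load_first_match helper are hypothetical stand-ins that only mimic the "try each Union member in order" behaviour described in the question.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class DataA:
    number: float


@dataclass
class DataB:
    number: float


def load_first_match(payload: dict, candidates: tuple):
    """Return an instance of the first candidate class that accepts the payload."""
    for cls in candidates:
        try:
            return cls(**payload)
        except TypeError:
            continue
    raise ValueError("no candidate accepted the payload")


json_a = json.dumps(asdict(DataA(number=1.0)))
json_b = json.dumps(asdict(DataB(number=1.0)))

# Both classes serialize to exactly the same text, so the type is lost.
assert json_a == json_b == '{"number": 1.0}'

# Trying the candidates left to right always stops at the first one.
restored = load_first_match(json.loads(json_b), (DataA, DataB))
print(type(restored).__name__)
# prints: DataA
```

Since the serialized form carries no trace of the original class, any loader has to either guess (first match) or be given extra information in the data itself.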

As of today (pydantic v1.8.2), the most canonical way to distinguish models when parsing a Union (in case of ambiguity) is to explicitly add a Literal-typed tag field to each model. It looks like this:

from abc import ABC
from pydantic import BaseModel
from typing import Union, Literal

class Data(BaseModel, ABC):
    """ base class for a Member """
    number: float


class DataA(Data):
    """ A type of Data"""
    tag: Literal['A'] = 'A'


class DataB(Data):
    """ Another type of Data """
    tag: Literal['B'] = 'B'


class Container(BaseModel):
    """ container holds a subclass of Data """
    data: Union[DataA, DataB]


# initialize container with DataB
data = DataB(number=1.0)
container = Container(data=data)

# export container to string and load new container from string
string = container.json()
new_container = Container.parse_raw(string)


# look at type of container.data
print(type(new_container.data).__name__)
# >>> DataB

This method can be automated, but use it at your own risk, since it breaks static typing and uses objects that may change in future versions:

from abc import ABC
from typing import Literal

from pydantic import BaseModel
from pydantic.fields import ModelField

class Data(BaseModel, ABC):
    """ base class for a Member """
    number: float

    def __init_subclass__(cls, **kwargs):
        name = 'tag'
        value = cls.__name__
        annotation = Literal[value]

        tag_field = ModelField.infer(name=name, value=value, annotation=annotation, class_validators=None, config=cls.__config__)
        cls.__fields__[name] = tag_field
        cls.__annotations__[name] = annotation


class DataA(Data):
    """ A type of Data"""
    pass


class DataB(Data):
    """ Another type of Data """
    pass

Upvotes: 6
