Asamaras

Reputation: 11

How to serialize a Pydantic model with polymorphism?

I tried to serialize a Pydantic model with an attribute that can be an instance of any of several subclasses of a base class. However, with a naive implementation the subclass instances are serialized as the base class.

After reading this issue I wrote the following code, but without any success:

from typing import Dict, Literal, Union

from pydantic import BaseModel, Field, RootModel


class NodeBase(BaseModel):
    id: str


class StartNode(NodeBase):
    type: Literal["start"] = "start"


class EndNode(NodeBase):
    type: Literal["end"] = "end"


class LLMNode(NodeBase):
    type: Literal["llm"] = "llm"
    name: str = Field(default_factory=lambda: id)
    purpose: str
    prompt: str
    model: Literal[
        "gpt-4o", "gpt4-turbo", "gpt-4", "gpt-3.5-turbo", "azure-gpt-3.5-turbo"
    ]


class NodeModel(RootModel):
    root: Union[StartNode, EndNode, LLMNode]


class Graph(BaseModel):
    nodes: Dict[str, NodeModel] = Field(default_factory=dict)

    def add_node(self, node: Union[StartNode, EndNode, LLMNode]) -> None:
        self.nodes[node.id] = NodeModel(root=node)


start_node = StartNode(id="start", type="start")
llm_node = LLMNode(id="llm", type="llm", purpose="test", prompt="test", model="gpt-4o")
end_node = EndNode(id="end", type="end")

# ========= Node tests =========
start_node_dict = start_node.model_dump()
llm_node_dict = llm_node.model_dump()
end_node_dict = end_node.model_dump()
# Is it possible to use model_validate with the base class?
start_node_from_dict = NodeBase.model_validate(start_node_dict)
llm_node_from_dict = NodeBase.model_validate(llm_node_dict)
end_node_from_dict = NodeBase.model_validate(end_node_dict)


assert start_node == start_node_from_dict
assert llm_node == llm_node_from_dict
assert end_node == end_node_from_dict


# ========= Graph tests =========
g = Graph()
g.add_node(start_node)
g.add_node(llm_node)
g.add_node(end_node)

g_dict = g.model_dump()
g_from_dict = Graph.model_validate(g_dict)
assert g == g_from_dict

This gives the following errors:

UserWarning: Pydantic serializer warnings:
  Expected `str` but got `builtin_function_or_method` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(
Traceback (most recent call last):
  File "file.py", line 53, in <module>
    assert start_node == start_node_from_dict
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

I would like to be able to dump a Graph whose nodes are instances of NodeBase subclasses such as StartNode or LLMNode, and to deserialize the graph back so that every node has the right type. In addition, it would be great if I could deserialize a subclass of NodeBase without knowing its type in advance, directly with NodeBase.model_validate(subclass_of_NodeBase_that_I_dont_know_the_type_of).
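
A note on the warning in the output: `Field(default_factory=lambda: id)` cannot see the `id` field of the model, so the lambda returns Python's built-in `id` function, which matches the `builtin_function_or_method` in the message. A minimal sketch of one way to default `name` to the node's `id`, assuming pydantic v2; the validator name `default_name_to_id` is made up for illustration and the other fields of LLMNode are left out:

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class LLMNode(BaseModel):
    id: str
    # Optional so that "not provided" can be told apart from a real name.
    name: Optional[str] = None

    @model_validator(mode="after")
    def default_name_to_id(self):
        # Hypothetical helper: runs after field validation, when self.id exists.
        if self.name is None:
            self.name = self.id
        return self


node = LLMNode(id="llm")
named = LLMNode(id="llm", name="summarizer")
```

With this, `node.name` falls back to the id while an explicitly passed name is kept as given.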

Upvotes: 1

Views: 858

Answers (2)

Asamaras

Reputation: 11

Thank you for your feedback. Here is the solution I used for my problem:

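from typing import Union

from pydantic import RootModel, model_validator

# Create the NodeTypes union from the node types list.
# node_types is defined elsewhere and shouldn't contain NodeBase.
NodeTypes = Union[tuple(node_types)]


class NodeModel(RootModel):
    root: NodeTypes

    # After validation, unwrap the RootModel so callers get the concrete node.
    @model_validator(mode="after")
    @classmethod
    def get_root(cls, obj):
        if hasattr(obj, "root"):
            return obj.root
        return obj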

I also changed how nodes are added to the graph:

def add_node(self: Self, node: NodeBase) -> None:
    """Add a node to the graph.
    :param node: An instance of the Node class
    """
    self.nodes[node.id] = node
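
For comparison, here is a minimal self-contained sketch of the full round trip. It is not the exact code above: instead of the RootModel with an after-validator, it assumes pydantic v2 and uses a discriminated union on the `type` field. The alias name `AnyNode` is made up for illustration, and LLMNode is trimmed to two of its fields.

```python
from typing import Annotated, Dict, Literal, Union

from pydantic import BaseModel, Field


class NodeBase(BaseModel):
    id: str


class StartNode(NodeBase):
    type: Literal["start"] = "start"


class EndNode(NodeBase):
    type: Literal["end"] = "end"


class LLMNode(NodeBase):
    type: Literal["llm"] = "llm"
    purpose: str
    prompt: str


# Hypothetical alias: the "type" field tells pydantic which class to build.
AnyNode = Annotated[Union[StartNode, EndNode, LLMNode], Field(discriminator="type")]


class Graph(BaseModel):
    nodes: Dict[str, AnyNode] = Field(default_factory=dict)

    def add_node(self, node: NodeBase) -> None:
        # Store the concrete node directly, as in the add_node above.
        self.nodes[node.id] = node


g = Graph()
g.add_node(StartNode(id="start"))
g.add_node(LLMNode(id="llm", purpose="test", prompt="test"))
g.add_node(EndNode(id="end"))

# Dump to plain dicts, then rebuild; each node comes back as its own class.
g_dict = g.model_dump()
g_from_dict = Graph.model_validate(g_dict)
assert g == g_from_dict
```

The idea is that the union never contains NodeBase itself, so validation always has to pick one of the concrete classes.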

Upvotes: 0

Victor Egiazarian

Reputation: 1136

I'd suggest using something like this. It is just an example of how a mapping by type and a custom function could help you:

from enum import Enum
from typing import Union, Type

from pydantic import BaseModel


class NodeType(str, Enum):
    START = "start"
    END = "end"
    LLM = "llm"


class NodeBase(BaseModel):
    type: NodeType


class StartNode(NodeBase):
    field: int

class EndNode(NodeBase):
    field2: int


class NodeModel(BaseModel):
    root: Union[StartNode, EndNode]

    @classmethod
    def from_dict(cls, data: dict) -> "NodeModel":
        type_map: dict[NodeType, Type[NodeBase]] = {
            NodeType.START: StartNode,
            NodeType.END: EndNode,
        }

        node_type = NodeType(data["type"])
        type_class = type_map[node_type]

        node = type_class.model_validate(obj=data)

        return cls(root=node)


NodeModel.from_dict(data={"field": 1, "type": "start"})  # root=StartNode(type=<NodeType.START: 'start'>, field=1)
NodeModel.from_dict(data={"field2": 2, "type": "end"})  # root=EndNode(type=<NodeType.END: 'end'>, field2=2)

Don't forget to add more validation! Hope it helps!
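
As a possible variation on the hand-written type map, pydantic v2 can do the dispatch itself through a `TypeAdapter` over a discriminated union. This is only a sketch under that assumption: the name `node_adapter` is made up, and `type` is declared as a `Literal` on each class here rather than as the `NodeType` enum used above.

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter, ValidationError


class StartNode(BaseModel):
    type: Literal["start"] = "start"
    field: int


class EndNode(BaseModel):
    type: Literal["end"] = "end"
    field2: int


# Hypothetical adapter: validates a dict into whichever class "type" names.
node_adapter = TypeAdapter(
    Annotated[Union[StartNode, EndNode], Field(discriminator="type")]
)

start = node_adapter.validate_python({"type": "start", "field": 1})
end = node_adapter.validate_python({"type": "end", "field2": 2})

# An unknown "type" value is rejected instead of being silently accepted.
try:
    node_adapter.validate_python({"type": "llm", "field": 1})
    unknown_rejected = False
except ValidationError:
    unknown_rejected = True
```

This covers part of the "more validation" mentioned above, since a missing or unknown `type` raises a `ValidationError` rather than a `KeyError`.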

Upvotes: 1
