Sharmiko

Reputation: 623

Refactoring a marshmallow schema to Pydantic 2

I have been struggling to refactor this marshmallow schema to Pydantic 2. The problem is the `AttributeValue` field: whenever I port it to Pydantic, I get confusing errors about this field. The field is dynamic. As you can see in the `_get_field` function, the type of its value (and the options it may take) depends on the attribute's `id` in the manifest. I cannot find the right way to define an equivalent field in Pydantic. Any suggestions?

import copy

from dataclasses import dataclass
from marshmallow import fields, schema, ValidationError, validates_schema, post_load, Schema


@dataclass(frozen=True, kw_only=True)
class Attribute:
    """
    A dataclass for storing the attributes of a segment.
    """

    id: str
    value: str | list[str]



def get_validation_error_kwargs(field_name=None):
    """
    Build the keyword arguments to forward to ``ValidationError``.

    :param field_name: Optional name of the field the error belongs to.
    :return: ``{"field_name": field_name}`` when a name is given,
        otherwise a fresh empty dict.
    """
    if field_name is None:
        return {}
    return {"field_name": field_name}


def make_single_option_validator(options, field_name=None):
    """
    Return a validator that accepts a single value drawn from ``options``.

    :param options: Container of permitted values (checked with ``in``).
    :param field_name: Optional field name to attach to the raised error.
    :return: A callable that raises ``ValidationError`` when its argument
        is not in ``options`` and returns ``None`` otherwise.
    """
    error_kwargs = get_validation_error_kwargs(field_name)

    def validate(value):
        if value in options:
            return
        message = f'"{value}" is not a valid option'
        raise ValidationError(message, **error_kwargs)

    return validate


def make_multiple_option_validator(options, field_name=None):
    """
    Return a validator requiring every item of a value to be in ``options``.

    :param options: Container of permitted values; ``set.difference`` is
        used, so it is iterated to determine the allowed items.
    :param field_name: Optional field name to attach to the raised error,
        mirroring ``make_single_option_validator``.
    :return: A callable that raises ``ValidationError`` listing the
        offending items (sorted, quoted) when any item is not allowed,
        and returns ``None`` otherwise.
    """
    validation_error_kwargs = get_validation_error_kwargs(field_name)

    def validate(values):
        invalid_values = set(values).difference(options)
        if invalid_values:
            invalid = ", ".join(sorted(f'"{value}"' for value in invalid_values))
            # marshmallow only collects ``ValidationError`` raised by field
            # validators; a ``ValueError`` would escape ``Schema.load``
            # unhandled instead of being reported as a field error.
            raise ValidationError(
                f"The following values are not valid options: {invalid}",
                **validation_error_kwargs,
            )

    return validate


class AttributeValue(fields.Field):
    """
    A field for validating and serialising attribute values.

    The concrete field used for a value is chosen at (de)serialisation
    time from the ``type`` recorded for the attribute's ``id`` in the
    parent schema's ``manifest.attributes`` (see ``type_fields``). When
    that manifest entry defines ``option_labels`` and its type appears in
    ``option_validator_factories``, the value is also checked against
    those options, unless the schema context sets
    ``skip_attributes_validation``.
    """

    # Manifest attribute type -> field used to (de)serialise its value.
    # These instances are shared by every AttributeValue; never mutate them.
    type_fields = {
        "single": fields.String(),
        "multi": fields.List(fields.String),
        "boolean": fields.Boolean(),
        "range": fields.Integer(),
        "account": fields.String(),
        "supplier": fields.List(fields.String),
    }
    # Manifest attribute type -> factory building its option validator.
    option_validator_factories = {
        "single": make_single_option_validator,
        "account": make_single_option_validator,
        "multi": make_multiple_option_validator,
        "supplier": make_multiple_option_validator,
    }
    # Used to produce ``None`` when the attribute ID isn't valid and a
    # type-specific field can't be found. This is required because
    # even though the ID is validated elsewhere, this field's
    # validation will always run.
    null_field = fields.Constant(None)

    def _get_field(self, attribute_id):
        """
        Return the field that handles values of ``attribute_id``.

        :param attribute_id: The attribute ID, or ``None`` when the input
            did not contain one.
        :return: The type-specific field (a private copy carrying an
            option validator when options apply), or ``null_field`` when
            the ID has no manifest entry.
        """
        manifest = self.parent.manifest  # type: ignore[attr-defined]
        manifest_attribute = manifest.attributes.get(attribute_id)
        if not manifest_attribute:
            return self.null_field

        attribute_type = manifest_attribute["type"]
        field = self.type_fields[attribute_type]

        # Attach option validation only if the attribute type supports it,
        # the manifest provides options, and the campaign isn't finalized
        # (frozen). For finalized campaigns the context carries
        # 'skip_attributes_validation' == True.
        if (
            attribute_type in self.option_validator_factories
            and "option_labels" in manifest_attribute
            and not self.parent.context.get("skip_attributes_validation", False)
        ):
            # Copy first: ``field`` is a shared class-level instance, so
            # assigning validators to it would leak into other attributes.
            field = copy.deepcopy(field)
            field.validators = [
                self.option_validator_factories[attribute_type](
                    manifest_attribute["option_labels"]
                )
            ]

        return field

    def _serialize(self, value, attr, obj, **kwargs):
        """Serialise ``obj.<attr>`` with the field matching ``obj.id``."""
        field = self._get_field(obj.id)
        return field.serialize(attr, obj)

    def _deserialize(self, value, attr, data, **kwargs):
        """
        Deserialise ``value`` with the field matching the input's ``id``.

        ``data`` is the raw input mapping. Its ``id`` may be missing (the
        ID is validated by its own field), in which case ``null_field`` is
        used rather than raising ``KeyError``.
        """
        attribute_id = data.get("id") if data is not None else None
        field = self._get_field(attribute_id)
        return field.deserialize(value, attr, data)


class AttributeSchema(Schema):
    """
    A schema for validating and serialising attributes.

    Must be created with ``context={"manifest": ...}``: the manifest is
    read through the ``manifest`` property, which ``AttributeValue`` uses
    to pick the field for each value. Loading returns an ``Attribute``.
    """

    # ``required`` so a payload without an ``id`` is reported as a field
    # error. Otherwise ``validate_values`` (``data['id']``) and
    # ``make_attribute`` (``Attribute`` has no default for ``id``) would
    # fail with KeyError/TypeError. On field errors both hooks are skipped.
    id = fields.String(required=True)
    value = AttributeValue(allow_none=True)

    @property
    def manifest(self):
        """The manifest supplied through the schema context."""
        return self.context["manifest"]

    @validates_schema
    def validate_values(self, data, **_):
        """Reject a missing or ``None`` value, e.g. one from an unknown ``id``."""
        if data.get("value") is None:
            raise ValidationError(f"{data['id']!r} value may not be null.")

    @post_load
    def make_attribute(self, data, **kwargs):
        """Build the ``Attribute`` dataclass from the loaded data."""
        return Attribute(**data)
    

And this is an example test case:

from datetime import date

import pytest
from marshmallow import ValidationError

from .pricing import Attribute, DateRange
from .schema import AttributeSchema


@pytest.fixture
def schema(manifest):
    """
    Provide an ``AttributeSchema`` whose context carries ``manifest``.
    """
    schema_context = {"manifest": manifest}
    return AttributeSchema(context=schema_context)


# (payload, expected Attribute) pairs, one per supported value type.
TYPE_INPUTS = [
    # Single-choice attribute: a plain string.
    (
        dict(id="industry", value="appliances"),
        Attribute(id="industry", value="appliances"),
    ),
    # Multi-choice attribute: a list of strings.
    (
        dict(id="day_part", value=["am", "pm"]),
        Attribute(id="day_part", value=["am", "pm"]),
    ),
    # Boolean attribute.
    (
        dict(id="restricted", value=True),
        Attribute(id="restricted", value=True),
    ),
]


@pytest.mark.parametrize("payload, expected_attribute", TYPE_INPUTS)
def test_deserialisiaton_succeeds(schema, payload, expected_attribute):
    """
    Loading each payload should yield the matching ``Attribute``,
    whatever the type of the attribute's value.
    """
    assert schema.load(payload) == expected_attribute

As you can see, the attribute value can be one of several types, and each type must be validated accordingly.

Any suggestions or guidance will be appreciated.

Upvotes: 0

Views: 28

Answers (0)

Related Questions