Reputation: 623
I have been struggling to refactor this marshmallow schema to pydantic 2. The problem is with the AttributeValue field. When I refactor this to pydantic, I always get strange errors regarding this field. The attribute field is dynamic and can accept many kinds of values, as you can see in the _get_field
function. I cannot seem to find the correct way to define a similar field in pydantic. Any suggestions?
import copy
from dataclasses import dataclass
from marshmallow import fields, schema, ValidationError, validates_schema, post_load, Schema
@dataclass(frozen=True, kw_only=True)
class Attribute:
"""
A dataclass for storing the attributes of a segment.
"""
id: str
value: str | list[str]
def get_validation_error_kwargs(field_name=None):
    """
    Build the optional keyword arguments for a validation error.

    Returns an empty dict when ``field_name`` is ``None``; otherwise a
    dict holding the name under the ``"field_name"`` key.
    """
    if field_name is None:
        return {}
    return {"field_name": field_name}
def make_single_option_validator(options, field_name=None):
    """
    Create a validator that accepts a value only if it is in ``options``.

    The returned callable takes a single value and raises
    ``ValidationError`` when the value is not one of ``options``. When
    ``field_name`` is given it is forwarded to the error.
    """
    error_kwargs = get_validation_error_kwargs(field_name)

    def validate(value):
        # Valid values need no further handling.
        if value in options:
            return
        raise ValidationError(f'"{value}" is not a valid option', **error_kwargs)

    return validate
def make_multiple_option_validator(options, field_name=None):
    """
    Create a validator that accepts a collection of values only if every
    value is in ``options``.

    The returned callable takes an iterable of hashable values and raises
    ``ValidationError`` listing, in sorted order, every value that is not
    one of ``options``. When ``field_name`` is given it is forwarded to the
    error, mirroring ``make_single_option_validator``.
    """
    validation_error_kwargs = get_validation_error_kwargs(field_name)

    def validate(values):
        invalid_values = set(values).difference(options)
        if invalid_values:
            invalid = ", ".join(sorted(f'"{value}"' for value in invalid_values))
            # Raise marshmallow's ValidationError (not ValueError) so the
            # failure is collected as a validation error instead of
            # escaping as an unhandled exception.
            raise ValidationError(
                f"The following values are not valid options: {invalid} ",
                **validation_error_kwargs,
            )

    return validate
class AttributeValue(fields.Field):
    """
    A field for validating and serialising attribute values.

    The field that actually handles a value is looked up at
    (de)serialisation time: the attribute's ID is resolved against the
    parent schema's manifest, and the manifest entry's ``"type"`` selects
    one of ``type_fields``.
    """

    # Maps a manifest attribute type to the field handling its value.
    type_fields = {
        "single": fields.String(),
        "multi": fields.List(fields.String),
        "boolean": fields.Boolean(),
        "range": fields.Integer(),
        "account": fields.String(),
        "supplier": fields.List(fields.String),
    }

    # Attribute types whose values are checked against the manifest's
    # "option_labels", mapped to the factory building that validator.
    option_validator_factories = {
        "single": make_single_option_validator,
        "account": make_single_option_validator,
        "multi": make_multiple_option_validator,
        "supplier": make_multiple_option_validator,
    }

    # Used to produce ``None`` when the attribute ID isn't valid and a
    # type-specific field can't be found. This is required because
    # even though the ID is validated elsewhere, this field's
    # validation will always run.
    null_field = fields.Constant(None)

    def _get_field(self, attribute_id):
        """
        Return the field to use for the attribute with ``attribute_id``.

        Falls back to ``null_field`` when the manifest has no (truthy)
        entry for the ID. Raises ``KeyError`` if the manifest entry names
        a type that is missing from ``type_fields``.
        """
        manifest = self.parent.manifest  # type: ignore[attr-defined]
        manifest_attribute = manifest.attributes.get(attribute_id)
        if not manifest_attribute:
            return self.null_field

        attribute_type = manifest_attribute["type"]
        field = self.type_fields[attribute_type]

        # Check if the attribute type requires that its value is
        # validated against a set of options
        needs_options = (
            attribute_type in self.option_validator_factories
            and "option_labels" in manifest_attribute
        )
        if not needs_options:
            return field

        # Work on a copy so the shared class-level field is never mutated.
        field = copy.deepcopy(field)
        if not self.parent.context.get("skip_attributes_validation", False):
            # Proceed with attribute's option validation
            # only if the campaign isn't finalized (frozen).
            # For finalized (frozen) campaign,
            # there will be 'skip_attributes_validation' == True.
            make_validator = self.option_validator_factories[attribute_type]
            field.validators = [make_validator(manifest_attribute["option_labels"])]
        return field

    def _serialize(self, value, attr, obj, **kwargs):
        """
        Serialise ``obj``'s attribute ``attr`` with the field selected by
        ``obj.id``.
        """
        field = self._get_field(obj.id)
        return field.serialize(attr, obj)

    def _deserialize(self, value, attr, data, **kwargs):
        """
        Deserialise ``value`` with the field selected by the payload's
        ``"id"`` entry.
        """
        # A payload without an "id" must not crash with ``KeyError`` here:
        # treat it like an unknown attribute so the null field is used.
        # NOTE(review): assumes ``manifest.attributes.get(None)`` simply
        # returns ``None`` (dict-like lookup) - confirm.
        attribute_id = data.get("id") if data else None
        field = self._get_field(attribute_id)
        return field.deserialize(value, attr, data)
class AttributeSchema(Schema):
    """
    A schema for validating and serialising attributes.

    The schema expects a ``"manifest"`` entry in its context; it is exposed
    through the ``manifest`` property. Successful loads return an
    ``Attribute`` instance.
    """

    # The ID is mandatory: without it the value cannot be interpreted and
    # ``Attribute`` cannot be constructed, so a missing ID is reported as a
    # validation error instead of surfacing later as KeyError/TypeError.
    id = fields.String(required=True)
    value = AttributeValue(allow_none=True)

    @property
    def manifest(self):
        """Return the manifest stored in the schema context."""
        return self.context["manifest"]

    @validates_schema
    def validate_values(self, data, **_):
        """
        Reject attributes whose value is missing or ``None``.

        Raises ``ValidationError`` naming the offending attribute ID.
        """
        if data.get("value") is None:
            # ``.get`` keeps the message buildable even if "id" is absent
            # (e.g. on a partial load).
            raise ValidationError(f"{data.get('id')!r} value may not be null.")

    @post_load
    def make_attribute(self, data, **kwargs):
        """Turn the validated data into an ``Attribute``."""
        return Attribute(**data)
And this is an example test case:
from datetime import date
import pytest
from marshmallow import ValidationError
from .pricing import Attribute, DateRange
from .schema import AttributeSchema
@pytest.fixture
def schema(manifest):
    """
    Provide an ``AttributeSchema`` whose context carries the manifest
    under the ``"manifest"`` key.
    """
    context = {"manifest": manifest}
    return AttributeSchema(context=context)
# Each case pairs a raw payload with the ``Attribute`` it should load into.
_SINGLE_CASE = (
    {"id": "industry", "value": "appliances"},
    Attribute(id="industry", value="appliances"),
)
_MULTI_CASE = (
    {"id": "day_part", "value": ["am", "pm"]},
    Attribute(id="day_part", value=["am", "pm"]),
)
_BOOLEAN_CASE = (
    {"id": "restricted", "value": True},
    Attribute(id="restricted", value=True),
)

# One entry per attribute value type exercised by the test below.
TYPE_INPUTS = [_SINGLE_CASE, _MULTI_CASE, _BOOLEAN_CASE]
@pytest.mark.parametrize("payload, expected_attribute", TYPE_INPUTS)
def test_deserialisiaton_succeeds(schema, payload, expected_attribute):
    """
    Loading a payload of any supported attribute value type should
    produce the matching ``Attribute``.
    """
    loaded = schema.load(payload)
    assert loaded == expected_attribute
As you can see, the attribute value can be one of multiple types and should be validated correctly.
Any suggestions or guidance would be appreciated.
Upvotes: 0
Views: 28