Jérôme

Reputation: 14714

Define fields programmatically in Marshmallow Schema

Say I have a Schema like this:

class MySchema(Schema):

    field_1 = Float()
    field_2 = Float()
    ...
    field_42 = Float()

Is there a way to add those fields programmatically to the class?

Something like this:

class MyClass(BaseClass):

    FIELDS = ('field_1', 'field_2',..., 'field_42')

    for field in FIELDS:
        setattr(?, field, Float())  # What do I replace this "?" with?

I've seen posts about adding attributes dynamically to class instances, but this is different because I want to add the fields to the class itself rather than to an instance.

The same question might apply to other model definition libraries, such as ODMs/ORMs (uMongo/MongoEngine, SQLAlchemy, ...).

Upvotes: 8

Views: 14794

Answers (6)

Navid Khan

Reputation: 1169

If you are using marshmallow 3 or later, you can take advantage of the Schema.from_dict method.

from marshmallow import Schema, fields

MySchema = Schema.from_dict(
    {
        "id": fields.Str(dump_only=True),
        "content": fields.Str(required=True),
    }
)

If the shape of your schema needs to change at run time, you can do something like this:

my_schema = {
    "id": fields.Str(dump_only=True),
}

if some_condition:
    my_schema["additional_field"] = fields.Str(dump_only=True)

MySchema = Schema.from_dict(my_schema)

This example is illustrated with more detail in this blog post.

@Panic also shares this example, but the answer is incomplete.

Upvotes: 1

Jérôme

Reputation: 14714

I managed to do it by subclassing the default metaclass:

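from marshmallow import Schema
from marshmallow.fields import Float
from marshmallow.schema import SchemaMeta


class MySchemaMeta(SchemaMeta):

    @classmethod
    def get_declared_fields(mcs, klass, cls_fields, inherited_fields, dict_cls=dict):
        fields = super().get_declared_fields(klass, cls_fields, inherited_fields, dict_cls)
        FIELDS = ('field_1', 'field_2',..., 'field_42')
        for field in FIELDS:
            # A new Float instance is created for each name
            fields.update({field: Float()})
        return fields

class MySchema(Schema, metaclass=MySchemaMeta):

    class Meta:
        strict = True  # marshmallow 2 option; removed in marshmallow 3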

I made this more generic:

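from marshmallow import Schema, SchemaOpts
from marshmallow.fields import Float
from marshmallow.schema import SchemaMeta


class DynamicSchemaOpts(SchemaOpts):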

    def __init__(self, meta):
        super().__init__(meta)
        self.auto_fields = getattr(meta, 'auto_fields', [])


class DynamicSchemaMeta(SchemaMeta):

    @classmethod
    def get_declared_fields(mcs, klass, cls_fields, inherited_fields, dict_cls):

        fields = super().get_declared_fields(klass, cls_fields, inherited_fields, dict_cls)

        for auto_field_list in klass.opts.auto_fields:
            field_names, field = auto_field_list
            field_cls = field['cls']
            field_args = field.get('args', [])
            field_kwargs = field.get('kwargs', {})
            for field_name in field_names:
                fields.update({field_name: field_cls(*field_args, **field_kwargs)})

        return fields


class MySchema(Schema, metaclass=DynamicSchemaMeta):

    OPTIONS_CLASS = DynamicSchemaOpts

    class Meta:
        strict = True
        auto_fields = [
            (FIELDS,
             {'cls': Float}),
        ]

I didn't write

class Meta:
    strict = True
    auto_fields = [
        (FIELDS, Float()),
    ]

because then all those fields would share the same Field instance.

The Field and its args/kwargs must be specified separately:

    class Meta:
        strict = True
        auto_fields = [
            (FIELDS,
             {'cls': Nested,
              'args': (MyEmbeddedSchema,),  # trailing comma: a one-element tuple for *args
              'kwargs': {'required': True}
             }),
        ]

I don't have an example of a use case that fails because several fields share the same instance, but it doesn't sound safe. If this precaution is unnecessary, the code could be simplified and made more readable:

    class Meta:
        strict = True
        auto_fields = [
            (FIELDS, Nested(MyEmbeddedSchema, required=True)),
        ]

Obviously, this answer is specific to Marshmallow and does not apply to other ODM/ORM libraries.

Upvotes: 2

Panic

Reputation: 2405

You can use marshmallow.Schema.from_dict to generate a mixin schema.

import marshmallow as ma

class MySchema(
    ma.Schema.from_dict({f"field_{i}": ma.fields.Int() for i in range(1, 4)})
):
    field_4 = ma.fields.Str()

Upvotes: 2

canary_in_the_data_mine

Reputation: 2393

The following method works for me.

I've demonstrated it using Marshmallow-SQLAlchemy because I'm not sure something like this is needed for plain Marshmallow anymore -- with version 3.0.0 it's pretty straightforward to programmatically create a schema using from_dict. But you could certainly use these concepts with plain Marshmallow.

Here, I use Marshmallow-SQLAlchemy to infer most of the schema, and then apply special treatment to a couple of the fields programmatically.

import enum

from marshmallow_enum import EnumField
from marshmallow_sqlalchemy import ModelSchema
from sqlalchemy import Column
from sqlalchemy import Enum
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy.ext.declarative import declarative_base


BaseResource = declarative_base()


class CustomEnum(enum.Enum):
    VALUE_1 = "the first value"
    VALUE_2 = "the second value"


class ExampleResource(BaseResource):
    __tablename__ = "example_resource"

    id = Column(Integer, primary_key=True)
    enum_field = Column(Enum(CustomEnum), nullable=False)
    title = Column(String)
    string_two = Column(String)

    def __init__(self, **kwargs):
        super(ExampleResource, self).__init__(**kwargs)


def generate_schema(class_, serialization_fields, serialization_fields_excluded):
    """A method for programmatically generating schema.

    Args:
        class_ (class): the class to generate the schema for
        serialization_fields (dict): key-value pairs with the field name and its Marshmallow `Field`
        serialization_fields_excluded (tuple): fields to exclude

    Returns:
        schema (marshmallow.schema.Schema): the generated schema

    """

    class MarshmallowBaseSchema(object):
        pass

    if serialization_fields is not None:
        for field, marshmallow_field in serialization_fields.items():
            setattr(MarshmallowBaseSchema, field, marshmallow_field)

    class MarshmallowSchema(MarshmallowBaseSchema, ModelSchema):
        class Meta:
            model = class_
            exclude = serialization_fields_excluded

    return MarshmallowSchema


generated_schema = generate_schema(
    class_=ExampleResource,
    # I'm using a special package to handle the field `enum_field`
    serialization_fields=dict(enum_field=EnumField(CustomEnum, by_value=True, required=True)),
    # I'm excluding the field `string_two`
    serialization_fields_excluded=("string_two",),
)

example_resource = ExampleResource(
    id=1,
    enum_field=CustomEnum.VALUE_2,
    title="A Title",
    string_two="This will be ignored."
)
print(generated_schema().dump(example_resource))
# {'title': 'A Title', 'id': 1, 'enum_field': 'the second value'}

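It's necessary to define MarshmallowBaseSchema as a plain class, add all the fields to it, and then inherit from it, because Marshmallow's SchemaMeta metaclass collects the declared fields when the schema class is created (they are stored in _declared_fields, which _init_fields() later reads when the schema is instantiated). Fields set with setattr on a schema class that already exists are not picked up, whereas fields on a base class are found through the MRO at class-creation time.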

Upvotes: 2

radzak

Reputation: 3118

The class Meta paradigm allows you to specify which attributes you want to serialize. Marshmallow will choose an appropriate field type based on the attribute’s type.

class MySchema(Schema):
    class Meta:
        fields = ('field_1', 'field_2', ..., 'field_42')
    ...

Refactoring: Implicit Field Creation

Upvotes: 2

Maxim Kulkin

Reputation: 2788

All you need to do is use the type() function to build your class with any attributes you want:

import marshmallow

MySchema = type('MySchema', (marshmallow.Schema,), {
    attr: marshmallow.fields.Float()
    for attr in FIELDS  # FIELDS as defined in the question
})

You can even have different types of fields there:

fields = {}
fields['foo'] = marshmallow.fields.Float()
fields['bar'] = marshmallow.fields.String()
MySchema = type('MySchema', (marshmallow.Schema,), fields)

or as a base for your customizations:

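class MySchema(type('_MySchema', (marshmallow.Schema,), fields)):
    @marshmallow.post_dump
    def update_something(self, data, **kwargs):
        # marshmallow 3 passes extra keyword arguments (e.g. many) and
        # uses the returned value as the serialized output
        return data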

Upvotes: 29
