LJG

Reputation: 767

How to place specific constraints on the parameters of a Pydantic model?

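How can I place specific constraints on the parameters of a Pydantic model? In particular, I would like:

- start_date must be at least "2019-01-01"
- end_date must be greater than start_date
- code and cluster must only accept the values listed in their respective sets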

The code I'm using is as follows:

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Set
import uvicorn

app = FastAPI()


class Query(BaseModel):
    start_date: str
    end_date: str
    code: Set[str] = {
        "A1", "A2", "A3", "A4",
        "X1", "X2", "X3", "X4", "X5",
        "Y1", "Y2", "Y3"
    }
    cluster: Set[str] = {"C1", "C2", "C3"}

@app.post("/")
async def read_table(query: Query):
    return {"msg": query}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

Upvotes: 4

Views: 3630

Answers (2)

se7en

Reputation: 870

You can use an Enum class or a Literal to validate code and cluster, and a root_validator for the dates. Also annotate the date fields with datetime instead of str, so Pydantic parses the incoming strings for you. Like so:

from datetime import datetime
from enum import Enum
from typing import Literal

from pydantic import BaseModel, root_validator

# using Literal to validate the code and cluster

class Query(BaseModel):
    start_date: datetime
    end_date: datetime
    code: Literal[
        "A1", "A2", "A3", "A4", "X1", "X2", "X3", "X4", "X5", "Y1", "Y2", "Y3"
    ]
    cluster: Literal["C1", "C2", "C3"]

    @root_validator(skip_on_failure=True)
    def validate_dates(cls, values):
        # skip_on_failure=True: this only runs when both dates parsed,
        # so values["start_date"] and values["end_date"] are present
        if values["start_date"] < datetime(year=2019, month=1, day=1):
            raise ValueError("start date cannot be earlier than 2019-01-01")

        if values["end_date"] <= values["start_date"]:
            raise ValueError("end date must be later than start date")

        return values
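To see the Literal constraint on its own, here is a minimal sketch. ClusterOnly and accepts are hypothetical names used only for this check, not part of the question's code:

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class ClusterOnly(BaseModel):  # hypothetical one-field model
    # same Literal annotation as the cluster field of Query
    cluster: Literal["C1", "C2", "C3"]


def accepts(value: str) -> bool:
    """Return True if `value` passes the Literal check, False otherwise."""
    try:
        ClusterOnly(cluster=value)
    except ValidationError:
        return False
    return True


print(accepts("C2"), accepts("C4"))  # True False
```

The match is exact, so a value such as "c1" (lower case) is rejected as well.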

If you would rather use an Enum to validate code and cluster, define the Enum classes like so:

class Cluster(Enum):
    C1 = "C1"
    C2 = "C2"
    C3 = "C3"


class Code(Enum):
    A1 = "A1"
    A2 = "A2"
    A3 = "A3"
    A4 = "A4"
    X1 = "X1"
    X2 = "X2"
    X3 = "X3"
    X4 = "X4"
    X5 = "X5"
    Y1 = "Y1"
    Y2 = "Y2"
    Y3 = "Y3"

Then replace the Literal annotations in the Query class with:

code: Code
cluster: Cluster
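A copy-paste slip such as C2 = "C3" is easy to miss: Enum silently turns C3 into an alias of C2, and the string "C2" is then rejected. The standard library's enum.unique decorator catches this when the class is created. A sketch, where ClusterQuery and typo_is_caught are hypothetical names:

```python
from enum import Enum, unique

from pydantic import BaseModel


@unique  # raises ValueError at class creation if two members share a value
class Cluster(Enum):
    C1 = "C1"
    C2 = "C2"
    C3 = "C3"


class ClusterQuery(BaseModel):  # hypothetical model, cluster field only
    cluster: Cluster


def typo_is_caught() -> bool:
    """Rebuild the enum with the C2 = "C3" typo and report whether @unique rejects it."""
    try:
        @unique
        class BrokenCluster(Enum):
            C1 = "C1"
            C2 = "C3"
            C3 = "C3"  # same value as C2, so @unique raises ValueError
    except ValueError:
        return True
    return False


# Pydantic looks the raw string up by value and returns the enum member
print(ClusterQuery(cluster="C2").cluster)  # Cluster.C2
print(typo_is_caught())  # True
```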

Upvotes: 2

MatsLindh

Reputation: 52792

Pydantic (1.10 and later) has a set of constrained types, such as condate, that allow you to define specific constraints on values.

start_date must be at least "2019-01-01"

>>> import datetime
>>> from pydantic import BaseModel, condate
>>> class Foo(BaseModel):
...   d: condate(ge=datetime.date.fromisoformat('2019-01-01'))
...

>>> Foo(d=datetime.date.fromisoformat('2018-01-12'))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "pydantic\main.py", line 342, in pydantic.main.BaseModel.__init__
pydantic.error_wrappers.ValidationError: 1 validation error for Foo
d
  ensure this value is greater than or equal to 2019-01-01 (type=value_error.number.not_ge; limit_value=2019-01-01)

>>> Foo(d=datetime.date.fromisoformat('2020-01-12'))
Foo(d=datetime.date(2020, 1, 12))

end_date must be greater than start_date

For more complicated rules, you can use a root validator:

from pydantic import BaseModel, root_validator
from datetime import date

class StartEnd(BaseModel):
    start: date
    end: date
    
    @root_validator(skip_on_failure=True)  # only runs once both dates parsed
    def validate_dates(cls, values):
        if values['start'] > values['end']:
            raise ValueError('start is after end')

        return values
        

StartEnd(start=date.fromisoformat('2023-01-01'), end=date.fromisoformat('2022-01-01'))

Gives:

pydantic.error_wrappers.ValidationError: 1 validation error for StartEnd
__root__
  start is after end (type=value_error)
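If one of the fields fails to parse, a root validator that indexes values['start'] would otherwise hit a KeyError. With skip_on_failure=True it is skipped, and only the field error is reported. A sketch; error_locations is a hypothetical helper for inspecting which fields failed:

```python
from datetime import date

from pydantic import BaseModel, ValidationError, root_validator


class StartEnd(BaseModel):
    start: date
    end: date

    # skipped entirely if `start` or `end` failed field validation
    @root_validator(skip_on_failure=True)
    def validate_dates(cls, values):
        if values['start'] > values['end']:
            raise ValueError('start is after end')
        return values


def error_locations(**data):
    """Hypothetical helper: return the `loc` of every validation error."""
    try:
        StartEnd(**data)
    except ValidationError as exc:
        return [err['loc'] for err in exc.errors()]
    return []


print(error_locations(start='not a date', end='2022-01-01'))  # [('start',)]
```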

For code and cluster, you can use an Enum instead:

from pydantic import BaseModel
from enum import Enum  # StrEnum in 3.11+


class ClusterEnum(str, Enum):
    C1 = "C1"
    C2 = "C2" 
    C3 = "C3"
    

class ClusterVal(BaseModel):
    cluster: ClusterEnum
        

print(ClusterVal(cluster='C3').cluster.value)
# outputs C3

Upvotes: 3
