Reputation: 4168
I am building a FastAPI application, which has a lot of Pydantic models. Even though the application works as expected, the OpenAPI (Swagger UI) docs do not show the schema for all of these models under the Schemas section.
Here are the contents of the Pydantic schemas.py:
import socket
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union

from pydantic import BaseModel, Field, validator
from typing_extensions import Literal

ResponseData = Union[List[Any], Dict[str, Any], BaseModel]

# Not visible in Swagger UI
class PageIn(BaseModel):
    page_size: int = Field(default=100, gt=0)
    num_pages: int = Field(default=1, gt=0, exclude=True)
    start_page: int = Field(default=1, gt=0, exclude=True)

# visible under schemas on Swagger UI
class PageOut(PageIn):
    total_records: int = 0
    total_pages: int = 0
    current_page: int = 1

    class Config:  # pragma: no cover
        @staticmethod
        def schema_extra(schema, model) -> None:
            schema.get("properties").pop("num_pages")
            schema.get("properties").pop("start_page")

# Not visible in Swagger UI
class BaseResponse(BaseModel):
    host_: str = Field(default_factory=socket.gethostname)
    message: Optional[str]

# Not visible in Swagger UI
class APIResponse(BaseResponse):
    count: int = 0
    location: Optional[str]
    page: Optional[PageOut]
    data: ResponseData

# Not visible in Swagger UI
class ErrorResponse(BaseResponse):
    error: str

# visible under schemas on Swagger UI
class BaseFaultMap(BaseModel):
    detection_system: Optional[str] = Field("", example="obhc")
    fault_type: Optional[str] = Field("", example="disk")
    team: Optional[str] = Field("", example="dctechs")
    description: Optional[str] = Field(
        "",
        example="Hardware raid controller disk failure found. "
        "Operation can continue normally,"
        "but risk of data loss exist",
    )

# Not visible in Swagger UI
class FaultQueryParams(BaseModel):
    f_id: Optional[int] = Field(None, description="id for the host", example=12345, title="Fault ID")
    hostname: Optional[str]
    status: Literal["open", "closed", "all"] = Field("open")
    created_by: Optional[str]
    environment: Optional[str]
    team: Optional[str]
    fault_type: Optional[str]
    detection_system: Optional[str]
    inops_filters: Optional[str] = Field(None)
    date_filter: Optional[str] = Field("",)
    sort_by: Optional[str] = Field("created",)
    sort_order: Literal["asc", "desc"] = Field("desc")
All of these models are actually used in FastAPI paths to validate the request body. FaultQueryParams is a custom model which I use to validate the request query params; it is used like below:
query_args: FaultQueryParams = Depends()
The rest of the models are used in conjunction with the Body field. I am not able to figure out why only some of the models are not visible in the Schemas section while others are.
Another thing I noticed about FaultQueryParams is that the description and example do not show up against the path endpoint, even though they are defined in the model.
Edit 1:
I looked more into it and realized that all of the models which are not visible in Swagger UI are the ones that are not used directly in path operations, i.e., they are not used as response_model or Body types; they are helper models that are used indirectly. So it seems FastAPI is not generating the schema for these models.
One exception to the above statement is query_args: FaultQueryParams = Depends(), which is used directly in a path operation to map the endpoint's Query params to a custom model. This is a problem because Swagger does not pick up metadata such as title, description and example from the fields of this model and does not show it in the UI, which is important for the users of this endpoint.
Is there a way to trick FastAPI into generating the schema for the custom model FaultQueryParams, just like it does for Body, Query, etc.?
Upvotes: 8
Views: 7163
Reputation: 34355
FastAPI will generate schemas for models that are used either as a request body or as a response model. When declaring query_args: FaultQueryParams = Depends() (using Depends), your endpoint does not expect a request body, but rather query parameters; hence, FaultQueryParams is not included in the schemas of the OpenAPI docs.
To add additional schemas, you could extend/modify the OpenAPI schema. An example is given below (make sure to add the code that modifies the schema after all routes have been defined, i.e., at the end of your code).
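from typing import Optional
from fastapi import FastAPI, Depends
from pydantic import BaseModel, Field
from typing_extensions import Literal

app = FastAPI()

class FaultQueryParams(BaseModel):
    f_id: Optional[int] = Field(None, description="id for the host", example=12345, title="Fault ID")
    hostname: Optional[str]
    status: Literal["open", "closed", "all"] = Field("open")
    ...

@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
    return query_args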
def get_extra_schemas():
    # Schema for FaultQueryParams, copied from the JSON that FastAPI/Pydantic would generate
    return {
        "FaultQueryParams": {
            "title": "FaultQueryParams",
            "type": "object",
            "properties": {
                "f_id": {
                    "title": "Fault ID",
                    "type": "integer",
                    "description": "id for the host",
                    "example": 12345
                },
                "hostname": {
                    "title": "Hostname",
                    "type": "string"
                },
                "status": {
                    "title": "Status",
                    "enum": [
                        "open",
                        "closed",
                        "all"
                    ],
                    "type": "string",
                    "default": "open"
                },
                # ...
            }
        }
    }
from fastapi.openapi.utils import get_openapi

def custom_openapi():
    # Return the cached schema if it has already been generated
    if app.openapi_schema:
        return app.openapi_schema
    openapi_schema = get_openapi(
        title="FastAPI",
        version="1.0.0",
        description="This is a custom OpenAPI schema",
        routes=app.routes,
    )
    # "components" may be missing if no route references a model, so create it if needed
    new_schemas = openapi_schema.setdefault("components", {}).setdefault("schemas", {})
    new_schemas.update(get_extra_schemas())
    app.openapi_schema = openapi_schema
    return app.openapi_schema

app.openapi = custom_openapi
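Stripped of FastAPI, the core of custom_openapi() is a dictionary merge into components.schemas. The stdlib-only sketch below isolates that step in a hypothetical helper, merge_extra_schemas (not part of FastAPI), and, as an assumed design choice, does not silently overwrite a schema that FastAPI already generated under the same name unless asked to.

```python
import copy
from typing import Any, Dict


def merge_extra_schemas(openapi_schema: Dict[str, Any],
                        extra: Dict[str, Any],
                        overwrite: bool = False) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of openapi_schema with extra
    schemas added under components.schemas."""
    result = copy.deepcopy(openapi_schema)
    # Create components/schemas if the generated schema has none
    schemas = result.setdefault("components", {}).setdefault("schemas", {})
    for name, schema in extra.items():
        if name in schemas and not overwrite:
            # Keep the schema FastAPI generated rather than clobbering it
            continue
        schemas[name] = schema
    return result


base = {"openapi": "3.0.2", "info": {"title": "FastAPI", "version": "1.0.0"}, "paths": {}}
extra = {"FaultQueryParams": {"title": "FaultQueryParams", "type": "object", "properties": {}}}
merged = merge_extra_schemas(base, extra)
print(sorted(merged["components"]["schemas"]))  # ['FaultQueryParams']
```

Inside custom_openapi() this would be called as openapi_schema = merge_extra_schemas(openapi_schema, get_extra_schemas()) before caching the result on app.openapi_schema.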
Instead of manually typing the schema for the extra models that you would like to add to the docs, you can have FastAPI generate it for you by temporarily adding an endpoint (which you would remove once you have the schema) that uses the model as a request body or response model, for example:
@app.post('/predict')
def predict(query_args: FaultQueryParams):
    return query_args
Then, you can get the generated JSON schema at http://127.0.0.1:8000/openapi.json, as described in the documentation. From there, you can either copy the model's schema into your code and use it directly (as shown in the get_extra_schemas() function above), or save it to a file and load the JSON data from that file, as demonstrated below:
import json
...
new_schemas = openapi_schema["components"]["schemas"]
with open('extra_schemas.json') as f:
    extra_schemas = json.load(f)
new_schemas.update(extra_schemas)
openapi_schema["components"]["schemas"] = new_schemas
...
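If you would rather not add and later remove a temporary endpoint, Pydantic itself can emit a model's JSON schema, which could then be passed to the same update() call. The sketch below assumes either Pydantic v1 (Model.schema()) or v2 (Model.model_json_schema()); the helper name model_schema is hypothetical. The output can differ slightly from what FastAPI renders (for example, v2 describes Optional fields with anyOf including null), and nested models are placed under definitions/$defs rather than directly under components.schemas, so this is best suited to flat models like FaultQueryParams.

```python
from typing import Any, Dict, Literal, Optional, Type

from pydantic import BaseModel, Field


class FaultQueryParams(BaseModel):
    f_id: Optional[int] = Field(None, description="id for the host", title="Fault ID")
    hostname: Optional[str] = None
    status: Literal["open", "closed", "all"] = "open"


def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
    """Hypothetical helper: JSON schema of a model under Pydantic v1 or v2."""
    # Point any $ref at the OpenAPI components section instead of the JSON Schema default
    ref_template = "#/components/schemas/{model}"
    if hasattr(model, "model_json_schema"):  # Pydantic v2
        return model.model_json_schema(ref_template=ref_template)
    return model.schema(ref_template=ref_template)  # Pydantic v1


# Could replace get_extra_schemas() or the JSON file in custom_openapi()
extra_schemas = {"FaultQueryParams": model_schema(FaultQueryParams)}
```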
To declare metadata, such as description, example, etc., for your query parameters, you should define them with Query instead of Field. Since you can't do that with Pydantic models, you either need to define the Query parameter(s) directly in the endpoint or use a custom dependency class, as shown below:
from fastapi import FastAPI, Query, Depends
from typing import Optional

class FaultQueryParams:
    def __init__(
        self,
        f_id: Optional[int] = Query(None, description="id for the host", example=12345)
    ):
        self.f_id = f_id

app = FastAPI()

@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
    return query_args
The above could be rewritten using the @dataclass decorator as follows:
from fastapi import FastAPI, Query, Depends
from typing import Optional
from dataclasses import dataclass

@dataclass
class FaultQueryParams:
    f_id: Optional[int] = Query(None, description="id for the host", example=12345)

app = FastAPI()

@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
    return query_args
There is no need to use a custom dependency class anymore, as FastAPI now allows using a Pydantic BaseModel to define query parameters by wrapping the Query() in a Field(); hence, one should be able to use a Pydantic model to define multiple query parameters and declare metadata for them, i.e., description, example, etc. Related answers could be found here and here.
Upvotes: 7
Reputation: 4168
Thanks to @Chris for the pointers, which ultimately led me to use dataclasses for defining query params in bulk, and it works just fine:
from dataclasses import dataclass
from typing import Optional
from fastapi import Query
from typing_extensions import Literal

@dataclass
class FaultQueryParams1:
    f_id: Optional[int] = Query(None, description="id for the host", example=55555)
    hostname: Optional[str] = Query(None, example="test-host1.domain.com")
    status: Literal["open", "closed", "all"] = Query(
        None, description="fetch open/closed or all records", example="all"
    )
    created_by: Optional[str] = Query(
        None,
        description="fetch records created by particular user",
        example="user-id",
    )
Upvotes: 2