Rohit

Reputation: 4168

OpenAPI is missing schemas for some of the Pydantic models in FastAPI app

I am building a FastAPI application that has a lot of Pydantic models. Although the application itself works as expected, the OpenAPI (Swagger UI) docs do not show the schema for all of these models under the Schemas section.

Here are the contents of my Pydantic schemas.py:

import socket
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union

from pydantic import BaseModel, Field, validator
from typing_extensions import Literal

ResponseData = Union[List[Any], Dict[str, Any], BaseModel]


# Not visible in Swagger UI
class PageIn(BaseModel):
    page_size: int = Field(default=100, gt=0)
    num_pages: int = Field(default=1, gt=0, exclude=True)
    start_page: int = Field(default=1, gt=0, exclude=True)

# visible under schemas on Swagger UI
class PageOut(PageIn):
    total_records: int = 0
    total_pages: int = 0
    current_page: int = 1

    class Config:  # pragma: no cover
        @staticmethod
        def schema_extra(schema, model) -> None:
            schema.get("properties").pop("num_pages")
            schema.get("properties").pop("start_page")


# Not visible in Swagger UI
class BaseResponse(BaseModel):
    host_: str = Field(default_factory=socket.gethostname)
    message: Optional[str]


# Not visible in Swagger UI
class APIResponse(BaseResponse):
    count: int = 0
    location: Optional[str]
    page: Optional[PageOut]
    data: ResponseData


# Not visible in Swagger UI
class ErrorResponse(BaseResponse):
    error: str


# visible under schemas on Swagger UI
class BaseFaultMap(BaseModel):
    detection_system: Optional[str] = Field("", example="obhc")
    fault_type: Optional[str] = Field("", example="disk")
    team: Optional[str] = Field("", example="dctechs")
    description: Optional[str] = Field(
        "",
        example="Hardware raid controller disk failure found. "
        "Operation can continue normally, "
        "but risk of data loss exists",
    )



# Not visible in Swagger UI
class FaultQueryParams(BaseModel):
    f_id: Optional[int] = Field(None, description="id for the host", example=12345, title="Fault ID")
    hostname: Optional[str]
    status: Literal["open", "closed", "all"] = Field("open")
    created_by: Optional[str]
    environment: Optional[str]
    team: Optional[str]
    fault_type: Optional[str]
    detection_system: Optional[str]
    inops_filters: Optional[str] = Field(None)
    date_filter: Optional[str] = Field("",)
    sort_by: Optional[str] = Field("created",)
    sort_order: Literal["asc", "desc"] = Field("desc")

All of these models are used in FastAPI path operations to validate the request body. FaultQueryParams is a custom model that I use to validate the request query params, and it is used like this:

query_args: FaultQueryParams = Depends()

The rest of the models are used in conjunction with the Body field. I cannot figure out why some of the models are not visible in the Schemas section while others are.

Another thing I noticed about FaultQueryParams is that the description and examples do not show up against the path endpoint, even though they are defined in the model.

Edit 1:

I looked into it further and realized that all of the models that are not visible in Swagger UI are the ones not used directly in path operations, i.e., they are not used as response_model or Body types and are helper models that are only used indirectly. So it seems that FastAPI does not generate the schema for these models.

One exception to the above is query_args: FaultQueryParams = Depends(), which is used directly in a path operation to map the query params of the endpoint onto a custom model. This is a problem because Swagger does not pick up metadata such as title, description and example from the fields of this model and does not show it in the UI, which is important for the users of this endpoint.

Is there a way to trick FastAPI into generating a schema for the custom model FaultQueryParams, just like it does for Body, Query, etc.?

Upvotes: 8

Views: 7163

Answers (2)

Chris

Reputation: 34355

FastAPI will generate schemas for models that are used either as a Request Body or Response Model. When declaring query_args: FaultQueryParams = Depends() (using Depends), your endpoint would not expect a request body, but rather query parameters; hence, FaultQueryParams would not be included in the schemas of the OpenAPI docs.

To add additional schemas, you could extend/modify the OpenAPI schema. An example is given below (make sure to add the code that modifies the schema after all routes have been defined, i.e., at the end of your code).

class FaultQueryParams(BaseModel):
    f_id: Optional[int] = Field(None, description="id for the host", example=12345, title="Fault ID")
    hostname: Optional[str]
    status: Literal["open", "closed", "all"] = Field("open")
    ...
    
@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
    return query_args

def get_extra_schemas():
    return {
              "FaultQueryParams": {
                "title": "FaultQueryParams",
                "type": "object",
                "properties": {
                  "f_id": {
                    "title": "Fault ID",
                    "type": "integer",
                    "description": "id for the host",
                    "example": 12345
                  },
                  "hostname": {
                    "title": "Hostname",
                    "type": "string"
                  },
                  "status": {
                    "title": "Status",
                    "enum": [
                      "open",
                      "closed",
                      "all"
                    ],
                    "type": "string",
                    "default": "open"
                  },
                   ...
                }
              }
            }

from fastapi.openapi.utils import get_openapi

def custom_openapi():
    if app.openapi_schema:
        return app.openapi_schema
    openapi_schema = get_openapi(
        title="FastAPI",
        version="1.0.0",
        description="This is a custom OpenAPI schema",
        routes=app.routes,
    )
    new_schemas = openapi_schema["components"]["schemas"]
    new_schemas.update(get_extra_schemas())
    openapi_schema["components"]["schemas"] = new_schemas
    
    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi
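As a possible alternative to writing the dictionary in get_extra_schemas() by hand, the schema could be produced from the model itself. The following is only a sketch: the model_schema helper is a hypothetical name, the model is shortened to two fields, and the code assumes either Pydantic v1 (schema()) or Pydantic v2 (model_json_schema()) is installed.

```python
from typing import Any, Dict, Optional, Type

from pydantic import BaseModel, Field


class FaultQueryParams(BaseModel):
    f_id: Optional[int] = Field(None, description="id for the host", title="Fault ID")
    hostname: Optional[str] = None


def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
    """Hypothetical helper: return the JSON schema of a model on Pydantic v1 or v2."""
    # Point any references at components.schemas, where OpenAPI expects them
    template = "#/components/schemas/{model}"
    if hasattr(model, "model_json_schema"):  # Pydantic v2
        return model.model_json_schema(ref_template=template)
    return model.schema(ref_template=template)  # Pydantic v1


def get_extra_schemas() -> Dict[str, Any]:
    # Keyed by model name, matching the layout of components.schemas
    return {FaultQueryParams.__name__: model_schema(FaultQueryParams)}


extra = get_extra_schemas()
```

For models that contain nested models, the generated schema carries the nested definitions in a separate section ("definitions" in v1, "$defs" in v2); those would still have to be moved into components.schemas, which this sketch does not do.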

Some Helpful Notes

Note 1

Instead of manually typing the schema for the extra models that you would like to add to the docs, you can have FastAPI do that for you by adding to your code an endpoint (which you would subsequently remove, after getting the schema) using that model as a request body or response model, for example:

@app.post('/predict') 
def predict(query_args: FaultQueryParams):
    return query_args

Then, you can get the generated JSON schema at http://127.0.0.1:8000/openapi.json, as described in the documentation. From there, you can either copy and paste the schema of the model into your code and use it directly (as shown in the get_extra_schemas() function above) or save it to a file and load the JSON data from the file, as demonstrated below:

import json
...

new_schemas = openapi_schema["components"]["schemas"]

with open('extra_schemas.json') as f:    
    extra_schemas = json.load(f)
    
new_schemas.update(extra_schemas)   
openapi_schema["components"]["schemas"] = new_schemas

...
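The fragment above can be wrapped in a small function so that it is easier to test in isolation. This is a sketch using only the standard library; merge_extra_schemas is a hypothetical name, and the temporary file merely stands in for a real extra_schemas.json. Using setdefault is a design choice to avoid a KeyError when the generated document has no components section yet.

```python
import json
import os
import tempfile
from typing import Any, Dict


def merge_extra_schemas(openapi_schema: Dict[str, Any], path: str) -> Dict[str, Any]:
    """Merge schemas loaded from a JSON file into components.schemas."""
    with open(path, encoding="utf-8") as f:
        extra_schemas = json.load(f)
    # Create components.schemas if the app has no body/response models yet
    schemas = openapi_schema.setdefault("components", {}).setdefault("schemas", {})
    schemas.update(extra_schemas)
    return openapi_schema


# Demo with a throwaway file in place of extra_schemas.json
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "extra_schemas.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"FaultQueryParams": {"title": "FaultQueryParams", "type": "object"}}, f)
    merged = merge_extra_schemas({"openapi": "3.0.2", "paths": {}}, path)
```

Inside custom_openapi(), the call would then take the place of the three lines that update new_schemas.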

Note 2

To declare metadata, such as description, example, etc., for a query parameter, you should define the parameter with Query instead of Field. Since you can't do that with Pydantic models, you need to either define the Query parameter(s) directly in the endpoint or use a custom dependency class, as shown below:

from fastapi import FastAPI, Query, Depends
from typing import Optional

class FaultQueryParams:
    def __init__(
        self,
        f_id: Optional[int] = Query(None, description="id for the host", example=12345),
    ):
        self.f_id = f_id

app = FastAPI()

@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
    return query_args

The above could be re-written using the @dataclass decorator as follows:

from fastapi import FastAPI, Query, Depends
from typing import Optional
from dataclasses import dataclass

@dataclass
class FaultQueryParams:
    f_id: Optional[int] = Query(None, description="id for the host", example=12345)

app = FastAPI()

@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
    return query_args

Update

There is no need for using a custom dependency class anymore, as FastAPI now allows using a Pydantic BaseModel to define query parameters by wrapping the Query() in a Field(); hence, one should be able to use a Pydantic model to define multiple query parameters and declare metadata for them, i.e., description, example, etc. Related answers could be found here and here.

Upvotes: 7

Rohit

Reputation: 4168

Thanks to @Chris for the pointers, which ultimately led me to use dataclasses for defining query params in bulk. It worked just fine.

from dataclasses import dataclass
from typing import Optional

from fastapi import Query
from typing_extensions import Literal


@dataclass
class FaultQueryParams1:
    f_id: Optional[int] = Query(None, description="id for the host", example=55555)
    hostname: Optional[str] = Query(None, example="test-host1.domain.com")
    # Optional, because the default is None rather than one of the literals
    status: Optional[Literal["open", "closed", "all"]] = Query(
        None, description="fetch open/closed or all records", example="all"
    )
    created_by: Optional[str] = Query(
        None,
        description="fetch records created by particular user",
        example="user-id",
    )

Upvotes: 2
