Vinuta Hiremath

Reputation: 21

Can I exclude a specific field from a Strawberry schema in FastAPI?

Introduction:
I'm working on a FastAPI project using Strawberry for GraphQL schemas, and I'm trying to inherit from a base input schema. However, I want to exclude a specific field from the derived schema.

Current Implementation:
Here’s the current structure of my input schemas:

import strawberry
from typing import Optional

@strawberry.input
class GenericRequest:
    country: Optional[str] = strawberry.field(default=None, description="Country where the lanes are located")
    segment_id: Optional[str] = strawberry.field(default=None, description="Unique identifier for the segment")

@strawberry.input
class LaneGraphRequest(GenericRequest):
    lane_graph_id: Optional[str] = strawberry.field(default=None, description="Id of a respective lane graph")
    
    # I want to exclude the 'country' field from this schema

Issue:
I want to exclude the country field from the LaneGraphRequest class while still inheriting segment_id from GenericRequest. Is there a proper way to achieve this in Strawberry?

What I've Tried:
Redefining country in LaneGraphRequest with exclude=True does not work; it raises a TypeError.

Conclusion:
Is there a recommended way to exclude a specific field from a Strawberry schema when inheriting from a base schema? Any insights or solutions would be greatly appreciated!

Upvotes: 0

Views: 130

Answers (1)

Kirill Ilichev

Reputation: 1279

Unfortunately, Strawberry has no built-in way to exclude a field from an inherited input class. You could try this hack:

import strawberry
from typing import Optional

@strawberry.input
class GenericRequest:
    country: Optional[str] = strawberry.field(default=None, description="Country where the lanes are located")
    segment_id: Optional[str] = strawberry.field(default=None, description="Unique identifier for the segment")

@strawberry.input
class LaneGraphRequest(GenericRequest):
    lane_graph_id: Optional[str] = strawberry.field(default=None, description="Id of a respective lane graph")

    def __post_init__(self):
        # Called by the generated dataclass __init__ after the fields are assigned.
        if hasattr(self, 'country'):
            delattr(self, 'country')

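Note that the __post_init__ hack only deletes the attribute from each instance after it is constructed; country is still part of the generated GraphQL input type, so clients can still send it. Instead, I would suggest restructuring the hierarchy: keep only the shared fields in the generic base class, and add country in a separate subclass for the schemas that actually need it: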

import strawberry
from typing import Optional

@strawberry.input
class GenericRequest:
    segment_id: Optional[str] = strawberry.field(default=None, description="Unique identifier for the segment")


@strawberry.input
class CountryRequest(GenericRequest):
    country: Optional[str] = strawberry.field(default=None, description="Country where the lanes are located")


@strawberry.input
class LaneGraphRequest(GenericRequest):
    lane_graph_id: Optional[str] = strawberry.field(default=None, description="Id of a respective lane graph")


Upvotes: 0
