plasmon360

Reputation: 4199

Parsing list of different models with Pydantic

I am trying to create a Pydantic model that can parse a JSON string containing a list of different kinds of models. See the example below:

test_object = """
{
  "distributions":
    [
      {
        "param_name": "test1",
        "attributes":
          {
            "distribution_name": "UniformDistribution",
            "low": "1.0",
            "high": "2.0"
          }
      },
      {
        "param_name": "test2",
        "attributes":
          {
            "distribution_name": "UniformDistribution",
            "low": "1.0",
            "high":"2.0"
          }
      },
      {
        "param_name": "test3",
        "attributes":
          {
            "distribution_name": "IntUniformDistribution",
            "low": "1",
            "high": "2",
            "q": 4
          }
      },
      {
        "param_name": "test4",
        "attributes":
          {
            "distribution_name": "DiscreteUniformDistribution",
            "low": "1.0",
            "high": "2.0",
            "step": ".1"
          }
      }
    ]
}
"""

This is my Pydantic model for parsing the above string:

from pydantic import BaseModel, conint, confloat, ValidationError
from typing import List, Literal, Union


class DiscreteUniformDistribution(BaseModel):
    distribution_name: Literal["DiscreteUniformDistribution"]
    low: float
    high: float
    q: confloat(gt=0)

class IntUniformDistribution(BaseModel):
    distribution_name: Literal["IntUniformDistribution"]
    low: int
    high: int
    step: conint(gt=0)

class UniformDistribution(BaseModel):
    distribution_name: Literal["UniformDistribution"]
    low: float
    high: float

class Distribution(BaseModel):
    param_name: str
    attributes: Union[
        DiscreteUniformDistribution,
        IntUniformDistribution,
        UniformDistribution,
    ]

class Distributions(BaseModel):
    distributions: List[Distribution]


try:
    Distributions.parse_raw(test_object)
except ValidationError as e:
    print(e)

However, I am getting the following error:

9 validation errors for Distributions
distributions -> 2 -> attributes -> distribution_name
  unexpected value; permitted: 'DiscreteUniformDistribution' (type=value_error.const; given=IntUniformDistribution; permitted=('DiscreteUniformDistribution',))
distributions -> 2 -> attributes -> step
  field required (type=value_error.missing)
distributions -> 2 -> attributes -> distribution_name
  unexpected value; permitted: 'UniformDistribution' (type=value_error.const; given=IntUniformDistribution; permitted=('UniformDistribution',))
distributions -> 3 -> attributes -> q
  field required (type=value_error.missing)
distributions -> 3 -> attributes -> distribution_name
  unexpected value; permitted: 'IntUniformDistribution' (type=value_error.const; given=DiscreteUniformDistribution; permitted=('IntUniformDistribution',))
distributions -> 3 -> attributes -> low
  value is not a valid integer (type=type_error.integer)
distributions -> 3 -> attributes -> high
  value is not a valid integer (type=type_error.integer)
distributions -> 3 -> attributes -> step
  value is not a valid integer (type=type_error.integer)
distributions -> 3 -> attributes -> distribution_name
  unexpected value; permitted: 'UniformDistribution' (type=value_error.const; given=DiscreteUniformDistribution; permitted=('UniformDistribution',))
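The count of 9 errors can be reproduced without Pydantic. Below is a simplified plain-Python simulation, assuming Pydantic v1 behaviour for a plain `Union` (the version the `parse_raw` call and the error format point to): each member is tried left to right, the first one that validates wins, and if none does, the errors from every member are reported. The names `QUESTION_MODELS`, `ATTRIBUTES`, `errors_for` and `validate_union` are hypothetical. `int()`/`float()` stand in for Pydantic's coercion, and the `gt=0` constraints are not modelled.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical field specs copied from the question's models (q/step as written there).
# int()/float() approximate Pydantic's lax coercion; gt=0 constraints are ignored.
QUESTION_MODELS: Dict[str, Dict[str, Callable]] = {
    "DiscreteUniformDistribution": {"low": float, "high": float, "q": float},
    "IntUniformDistribution": {"low": int, "high": int, "step": int},
    "UniformDistribution": {"low": float, "high": float},
}

# The four "attributes" objects from the question's JSON.
ATTRIBUTES: List[dict] = [
    {"distribution_name": "UniformDistribution", "low": "1.0", "high": "2.0"},
    {"distribution_name": "UniformDistribution", "low": "1.0", "high": "2.0"},
    {"distribution_name": "IntUniformDistribution", "low": "1", "high": "2", "q": 4},
    {"distribution_name": "DiscreteUniformDistribution", "low": "1.0", "high": "2.0", "step": ".1"},
]


def errors_for(model: str, fields: Dict[str, Callable], data: dict) -> List[str]:
    """Return the errors a single Union member would report for `data`."""
    errors = []
    if data.get("distribution_name") != model:  # stands in for the Literal[...] check
        errors.append("distribution_name: unexpected value")
    for name, convert in fields.items():
        if name not in data:
            errors.append(f"{name}: field required")
            continue
        try:
            convert(data[name])
        except (TypeError, ValueError):
            errors.append(f"{name}: not a valid {convert.__name__}")
    return errors


def validate_union(data: dict, models: Dict[str, Dict[str, Callable]]) -> Tuple[str, List[str]]:
    """Try each member left to right; the first one without errors wins.
    If no member matches, return the errors collected from all of them."""
    collected: List[str] = []
    for model, fields in models.items():
        errors = errors_for(model, fields, data)
        if not errors:
            return model, []
        collected.extend(f"{model} -> {e}" for e in errors)
    return "", collected


total = 0
for attrs in ATTRIBUTES:
    matched, errors = validate_union(attrs, QUESTION_MODELS)
    total += len(errors)
    print(attrs["distribution_name"], "->", matched or "no match", f"({len(errors)} errors)")
print("total errors:", total)  # prints "total errors: 9"
```

The entry at index 2 fails every member with 3 errors and the entry at index 3 with 6, which lines up with the error listing above; the two `UniformDistribution` entries match their own member, so they contribute nothing.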

It works fine if there is only one element in the distributions list, as shown below:

test_object = """
{
  "distributions": 
    [
      {
        "param_name": "test1",
        "attributes":
          {
            "distribution_name": "UniformDistribution",
            "low": "1.0",
            "high": "2.0"
          }
      }
    ]
}
"""

I am new to Pydantic. I suspect the error has to do with my improper usage of Union.

Upvotes: 1

Views: 2998

Answers (1)

Paul P

Reputation: 3927

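Your usage of Union[] looks good; however, there is a typo in your model definitions. In your JSON, test3 (IntUniformDistribution) has a q field and test4 (DiscreteUniformDistribution) has a step field, but your models expect them the other way round. Neither entry therefore validates against any member of the Union, and Pydantic reports the errors from every member it tried. The single-element example works because test1 only needs UniformDistribution, which has no such mismatch.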

You need to swap q in DiscreteUniformDistribution() with step in IntUniformDistribution() (only the field names, not the types), i.e.:

class DiscreteUniformDistribution(BaseModel):
    distribution_name: Literal['DiscreteUniformDistribution']
    low: float
    high: float
    step: confloat(gt=0)
#   ^^^^
#   This is called q in your definition

class IntUniformDistribution(BaseModel):
    distribution_name: Literal['IntUniformDistribution']
    low: int
    high: int
    q: conint(gt=0)
#   ^
#   This is called step in your definition

Upvotes: 1
