I am using Pydantic to create a Timeseries model based on pandas Timestamp (start, end) and Timedelta (period) objects. The model will be used by a small data analysis program with a number of configurations/scenarios.
I need to instantiate and validate aspects of the Timeseries model based on two bool (include_end_period, allow_future) and one optional int (max_periods) config params. I then need to derive three new fields (timezone, total_duration, total_periods) and perform some additional validations.
Because I often need one value when validating another, I was unable to achieve the desired result with the typical @validator methods. In particular, I would often get a KeyError for a missing key instead of the expected ValueError. The best solution I've found is to instead create one long @root_validator(pre=True) method.

from pydantic import BaseModel, ValidationError, root_validator, conint
from pandas import Timestamp, Timedelta


class Timeseries(BaseModel):
    start: Timestamp
    end: Timestamp
    period: Timedelta
    include_end_period: bool = False
    allow_future: bool = True
    max_periods: conint(gt=0, strict=True) | None = None

    # Derived values, do not pass as params
    timezone: str | None
    total_duration: Timedelta
    total_periods: conint(gt=0, strict=True)

    class Config:
        extra = 'forbid'
        validate_assignment = True

    @root_validator(pre=True)
    def _validate_model(cls, values):
        # Validate input values
        if values['start'] > values['end']:
            raise ValueError('Start timestamp cannot be later than end')
        if values['start'].tzinfo != values['end'].tzinfo:
            raise ValueError('Start, end timezones do not match')
        if values['period'] <= Timedelta(0):
            raise ValueError('Period must be a positive amount of time')

        # Set timezone
        timezone = values['start'].tzname()
        if 'timezone' in values and values['timezone'] != timezone:
            raise ValueError('Timezone param does not match start timezone')
        values['timezone'] = timezone

        # Set duration (add 1 period if including end period)
        total_duration = values['end'] - values['start']
        if values['include_end_period']:
            total_duration += values['period']
        if 'total_duration' in values and values['total_duration'] != total_duration:
            error_context = ' + 1 period (included end period)' if values['include_end_period'] else ''
            raise ValueError(f'Duration param does not match end - start timestamps{error_context}')
        values['total_duration'] = total_duration

        # Set total_periods
        periods_float: float = values['total_duration'] / values['period']
        if periods_float != int(periods_float):
            raise ValueError('Total duration not divisible by period length')
        total_periods = int(periods_float)
        if 'total_periods' in values and values['total_periods'] != total_periods:
            raise ValueError('Total periods param does not match')
        values['total_periods'] = total_periods

        # Validate future
        if not values['allow_future']:
            # Get current timestamp to floor of period (subtract 1 period if including end period)
            max_end: Timestamp = Timestamp.now(tz=values['timezone']).floor(freq=values['period'])
            if values['include_end_period']:
                max_end -= values['period']
            if values['end'] > max_end:
                raise ValueError('End period is future or current (incomplete)')

        # Validate derived values
        if values['total_duration'] < Timedelta(0):
            raise ValueError('Total duration must be positive amount of time')
        if values['max_periods'] and values['total_periods'] > values['max_periods']:
            raise ValueError('Total periods exceeds max periods param')

        return values

Instantiating the model in the happy case, using all config checks:
start = Timestamp('2023-03-01T00:00:00Z')
end = Timestamp('2023-03-02T00:00:00Z')
period = Timedelta('5min')
try:
    ts = Timeseries(start=start, end=end, period=period,
                    include_end_period=True, allow_future=False, max_periods=10000)
    print(ts.dict())
except ValidationError as e:
    print(e)
Output:
"""
{'start': Timestamp('2023-03-01 00:00:00+0000', tz='UTC'),
'end': Timestamp('2023-03-02 00:00:00+0000', tz='UTC'),
'period': Timedelta('0 days 00:05:00'),
'include_end_period': True,
'allow_future': False,
'max_periods': 10000,
'timezone': 'UTC',
'total_duration': Timedelta('1 days 00:05:00'),
'total_periods': 289}
"""
Here I believe all my validation is working as expected, and delivers the expected ValueErrors instead of less helpful KeyErrors. Is this approach reasonable? It seems to go against the typical/recommended approach, and the @root_validator documentation is quite brief compared to that of @validator.
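
The derived values in the output above can be checked by hand. The following is a minimal sketch that uses the standard library's datetime and timedelta in place of the pandas types (an assumption made only to keep it dependency-free); the arithmetic is the same as in the validator.

```python
from datetime import datetime, timedelta, timezone

# Same inputs as the happy-case example, with stdlib types standing in
# for pandas Timestamp/Timedelta (assumption for illustration only).
start = datetime(2023, 3, 1, tzinfo=timezone.utc)
end = datetime(2023, 3, 2, tzinfo=timezone.utc)
period = timedelta(minutes=5)
include_end_period = True

# One day between start and end, plus one extra period for the included end.
total_duration = end - start
if include_end_period:
    total_duration += period

# Dividing one timedelta by another yields a float number of periods.
periods_float = total_duration / period
total_periods = int(periods_float)

print(total_duration)  # 1 day, 0:05:00
print(total_periods)   # 289
```

One day holds 288 five-minute periods, and including the end period adds one more, which gives the 289 shown in the output.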
I am also unsatisfied that I need to list the derived values (timezone, total_duration, total_periods) at the top of the model. This implies they can/should be passed when instantiating, and requires extra logic in my validator to check whether they were passed and whether they match the derived values. If I omitted them, they would not benefit from the default validation of type, constraints, etc., and I would be forced to change the config to extra='allow'. I would appreciate any tips on how to improve this.
Thank you!
Upvotes: 1
Views: 3747
Reputation: 18663
For testing purposes it is usually a good idea not to have such large functions. Even if you want to go for the root_validator approach, you can (and IMO should) still divide up the logic into distinct, semantically sensible methods.
But I would suggest a slightly different approach altogether. Since timezone, total_duration and total_periods are derived from other fields and that process is not very expensive, I would define properties for those instead of having them as fields.
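This has the advantage that you don't need to compute their values in advance, which means you don't need the pre=True approach and can utilize previously validated field values in field-specific validators. (Keep in mind that in Pydantic v1 a field that failed its own validation is absent from values, so a lookup such as values["start"] may still need a guard if you want to avoid a KeyError.)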
Root validators still make sense when you really need to ensure that many distinct fields taken together follow a certain logic.
Here is what I propose:

from collections.abc import Mapping
from typing import Any

from pandas import Timedelta, Timestamp
from pydantic import BaseModel, conint, root_validator, validator

AnyMap = Mapping[str, Any]


class Timeseries(BaseModel):
    start: Timestamp
    end: Timestamp
    period: Timedelta
    include_end_period: bool = False
    allow_future: bool = True
    max_periods: conint(gt=0, strict=True) | None = None

    @validator("end")
    def ensure_end_consistent_with_start(
        cls,
        v: Timestamp,
        values: AnyMap,
    ) -> Timestamp:
        val_start: Timestamp = values["start"]
        if v < val_start:
            raise ValueError("Start timestamp cannot be later than end")
        if val_start.tzinfo != v.tzinfo:
            raise ValueError("Start, end timezones do not match")
        return v

    @validator("period")
    def ensure_period_is_positive(cls, v: Timedelta) -> Timedelta:
        if v <= Timedelta(0):
            raise ValueError("Period must be a positive amount of time")
        return v

    @validator("period")
    def ensure_period_divides_duration(
        cls,
        v: Timedelta,
        values: AnyMap,
    ) -> Timedelta:
        duration: float = (values["end"] - values["start"]) / v
        if int(duration) != duration:
            raise ValueError("Total duration not divisible by period length")
        return v

    @root_validator
    def ensure_end_is_allowed(cls, values: AnyMap) -> AnyMap:
        if values["allow_future"]:
            return values
        val_period: Timedelta = values["period"]
        val_end: Timestamp = values["end"]
        max_end = Timestamp.now(tz=val_end.tzname()).floor(freq=val_period)
        if values["include_end_period"]:
            max_end -= val_period
        if val_end > max_end:
            raise ValueError("End period is future or current (incomplete)")
        return values

    @root_validator
    def ensure_num_periods_allowed(cls, values: AnyMap) -> AnyMap:
        periods = int((values["end"] - values["start"]) / values["period"])
        if values["include_end_period"]:
            periods += 1
        if values["max_periods"] and periods > values["max_periods"]:
            raise ValueError("Total periods exceeds max periods param")
        return values

    @property
    def timezone(self) -> str | None:
        return self.start.tzname()

    @property
    def total_duration(self) -> Timedelta:
        total_duration = self.end - self.start
        if self.include_end_period:
            total_duration += self.period
        return total_duration

    @property
    def total_periods(self) -> int:
        return int(self.total_duration / self.period)

I guess it is a matter of personal preference when to switch from field validators to root validators. For example, you could argue that ensure_period_divides_duration should be a root validator, since it uses the values of three fields.
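
Wherever the divisibility check ends up living, it can also be written without going through a float. The following is a minimal sketch using the modulo operator on standard-library timedelta objects; the helper name period_divides_duration is hypothetical, and using stdlib types instead of pandas Timedelta is an assumption made to keep the example self-contained.

```python
from datetime import datetime, timedelta


def period_divides_duration(start: datetime, end: datetime, period: timedelta) -> bool:
    """Return True if (end - start) is a whole number of periods.

    Hypothetical helper for illustration; not part of the model above.
    """
    if period <= timedelta(0):
        raise ValueError("Period must be a positive amount of time")
    # timedelta % timedelta gives the remainder as a timedelta,
    # so no float comparison is involved.
    return (end - start) % period == timedelta(0)


start = datetime(2023, 3, 1)
end = datetime(2023, 3, 2)
print(period_divides_duration(start, end, timedelta(minutes=5)))  # True
print(period_divides_duration(start, end, timedelta(minutes=7)))  # False
```

Comparing the remainder to zero avoids relying on int(x) == x for a float quotient, which may be a consideration for very long durations with very short periods.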
Your example data of course works with this model as well.
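One thing to note is that the range constraint on total_periods is mostly redundant anyway, when you validate that end is after start (and that period evenly divides the total duration). The one exception is start equal to end without include_end_period, which passes the end validator but yields zero periods.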
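
The boundary case where start equals end can be checked with a minimal sketch. It uses standard-library types in place of the pandas ones (an assumption for self-containment), and the helper count_periods is hypothetical; it mirrors the arithmetic of the total_periods property.

```python
from datetime import datetime, timedelta


def count_periods(
    start: datetime,
    end: datetime,
    period: timedelta,
    include_end_period: bool = False,
) -> int:
    """Hypothetical helper mirroring the total_periods property."""
    total = end - start
    if include_end_period:
        total += period
    return int(total / period)


t = datetime(2023, 3, 1)
five_min = timedelta(minutes=5)

# start == end is not rejected by a check of the form `end < start`.
print(count_periods(t, t, five_min))                           # 0
print(count_periods(t, t, five_min, include_end_period=True))  # 1
```

If zero periods should be rejected, as the original gt=0 constraint did, the end validator could compare with <= instead of <, or a root validator could check the period count explicitly.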
You could also argue that even something as simple as total_duration should not be a property. In that case you could make it a method called get_total_duration.
But if you have those "derived" fields, you'll always run into the issue of having to check that whatever was passed by the user is consistent with the rest of the data.
I believe most of this headache will be gone, once Pydantic v2 drops, which promises computed fields (see the plan for v2).
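
Pydantic v2 has since been released and provides a computed_field decorator. The following is a minimal sketch of how the derived values might look with it. It assumes Pydantic v2 is installed, uses standard-library datetime and timedelta instead of the pandas types to stay self-contained, and leaves out the validation logic shown earlier.

```python
from datetime import datetime, timedelta, timezone

from pydantic import BaseModel, computed_field


class Timeseries(BaseModel):
    start: datetime
    end: datetime
    period: timedelta
    include_end_period: bool = False

    # Computed fields are derived on access and included when the model
    # is dumped, but they are not accepted as constructor arguments.
    @computed_field
    @property
    def total_duration(self) -> timedelta:
        total = self.end - self.start
        if self.include_end_period:
            total += self.period
        return total

    @computed_field
    @property
    def total_periods(self) -> int:
        return int(self.total_duration / self.period)


ts = Timeseries(
    start=datetime(2023, 3, 1, tzinfo=timezone.utc),
    end=datetime(2023, 3, 2, tzinfo=timezone.utc),
    period=timedelta(minutes=5),
    include_end_period=True,
)
print(ts.model_dump()["total_periods"])  # 289
```

With this layout the derived values no longer need to be declared as input fields, so there is nothing for a caller to pass in and nothing to cross-check.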