Reputation: 1470
Suppose I have the following function:
import pandas as pd

def foo(df: pd.DataFrame) -> pd.DataFrame:
    x = df["x"]
    y = df["y"]
    df["xy"] = x * y
    return df
Is there a way to hint that my function accepts a data frame that must have the "x" and "y" columns, and that it returns a data frame with the "x", "y" and "xy" columns, rather than just a general data frame?
Upvotes: 10
Views: 5404
Reputation: 1470
I'm not sure this is the correct way to implement it, but it seems to work for me. If you see any mistakes or alternatives, let me know and I can edit the answer. My solution was to create a new class and implement the __class_getitem__ method described in PEP 560. This was my final code:
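from typing import List

import pandas as pd

# type(List[str]) is the alias class that typing uses for subscripted generics.
GenericAlias = type(List[str])

class MyDataFrame(pd.DataFrame):
    # Makes MyDataFrame[...] return a generic alias (see PEP 560).
    __class_getitem__ = classmethod(GenericAlias)

def foo(df: MyDataFrame[["x", "y"]]) -> MyDataFrame[["x", "y", "xy"]]:
    df["xy"] = df["x"] * df["y"]
    return df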
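To see what such a hint actually stores, here is a minimal sketch using only the standard library. It makes a few assumptions that are not in the code above: it uses the public types.GenericAlias (Python 3.9+) in place of type(List[str]), a hypothetical Frame class built on dict stands in for pd.DataFrame so that nothing third-party is needed, and required_columns is a made-up helper name. The point is that the column list only lives in the annotation; nothing checks it unless you read it back yourself.

```python
from types import GenericAlias  # public since Python 3.9 (assumption: 3.9+)


class Frame(dict):
    """Hypothetical stand-in for pd.DataFrame, used only for illustration."""

    # Frame[...] returns GenericAlias(Frame, ...), the PEP 560 hook.
    __class_getitem__ = classmethod(GenericAlias)


def foo(df: Frame[["x", "y"]]) -> Frame[["x", "y", "xy"]]:
    # Element-wise product of the two columns, stored as plain lists.
    df["xy"] = [a * b for a, b in zip(df["x"], df["y"])]
    return df


def required_columns(func, name):
    """Hypothetical helper: read the column list kept in an annotation.

    `name` is a parameter name, or "return" for the return annotation.
    """
    alias = func.__annotations__[name]
    # A non-tuple subscript is wrapped into a one-element tuple in __args__.
    return list(alias.__args__[0])


if __name__ == "__main__":
    print(required_columns(foo, "df"))
    print(required_columns(foo, "return"))
```

The annotation is informational only: foo would still accept a Frame without a "y" key and fail later with a KeyError, which is the gap a validation library fills.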
The Pandera library is a better alternative that I did not know about when I posted the question.
import pandas as pd
import pandera as pa
from pandera.typing import Series, DataFrame

class InputSchema(pa.DataFrameModel):
    x: Series[float]
    y: Series[float]

class ReturnSchema(pa.DataFrameModel):
    x: Series[float]
    y: Series[float]
    xy: Series[float]

# Without the @pa.check_types decorator these annotations are hints only;
# add it above foo to validate both schemas at call time.
def foo(df: DataFrame[InputSchema]) -> DataFrame[ReturnSchema]:
    x = df['x']
    y = df['y']
    df['xy'] = x * y
    return df
Upvotes: 6