Reputation: 21961
Is there a scikit-learn preprocessor I can use or implement to select a subset of rows from a pandas DataFrame? I would prefer a preprocessor for this, since I want to build a pipeline with it as a step.
Upvotes: 6
Views: 1091
Reputation: 4529
You can use a FunctionTransformer to do that. A FunctionTransformer accepts any callable that takes the input data X and returns the transformed data, which is the same interface as a standard scikit-learn transform call. In code:
import pandas as pd
from sklearn.preprocessing import FunctionTransformer

class RowSelector:
    def __init__(self, rows: list[int]):
        self._rows = rows

    def __call__(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
        return X.iloc[self._rows, :]

selector = FunctionTransformer(RowSelector(rows=[1, 3]))
df = pd.DataFrame({'a': range(4), 'b': range(4), 'c': range(4)})
selector.fit_transform(df)
# Returns
#    a  b  c
# 1  1  1  1
# 3  3  3  3
Note that I have used a callable object to track some state, i.e. the rows to be selected. This is not necessary and could be solved differently.
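One way to do it differently, as a minimal sketch: use a plain function and pass the row positions through the `kw_args` parameter of FunctionTransformer. The function name `select_rows` and its `rows` parameter are made up for illustration.

```python
import pandas as pd
from sklearn.preprocessing import FunctionTransformer

# Hypothetical helper: a plain function instead of a callable object
def select_rows(X: pd.DataFrame, rows: list[int]) -> pd.DataFrame:
    # Positional row selection, same as in the class-based version
    return X.iloc[rows, :]

# kw_args is forwarded to select_rows as keyword arguments on each transform
selector = FunctionTransformer(select_rows, kw_args={"rows": [1, 3]})

df = pd.DataFrame({"a": range(4), "b": range(4), "c": range(4)})
result = selector.fit_transform(df)
```

Here the state lives in the transformer's `kw_args` rather than in a custom object, which is one less class to maintain.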
The nice part is that it returns a DataFrame, so if it is the first step of your pipeline, you can follow it with a ColumnTransformer that selects columns by name (if needed, of course).
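A rough sketch of that combination, assuming a row-selection step followed by a ColumnTransformer that scales two of the columns. The step names, the helper `select_rows`, and the choice of StandardScaler are all illustrative, not part of the question.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer, StandardScaler

# Hypothetical helper that keeps the given row positions
def select_rows(X: pd.DataFrame, rows: list[int]) -> pd.DataFrame:
    return X.iloc[rows, :]

pipeline = Pipeline([
    # Step 1: keep rows 1 and 3; the output is still a DataFrame
    ("rows", FunctionTransformer(select_rows, kw_args={"rows": [1, 3]})),
    # Step 2: columns can be addressed by name because step 1 returned a DataFrame
    ("cols", ColumnTransformer(
        [("scale", StandardScaler(), ["a", "b"])],
        remainder="passthrough",  # column "c" is passed through unchanged
    )),
])

df = pd.DataFrame({"a": range(4), "b": range(4), "c": range(4)})
result = pipeline.fit_transform(df)  # NumPy array with 2 rows and 3 columns
```

One caveat: a step that drops rows only changes X, so if the pipeline ends in a supervised estimator that also receives y, the lengths of X and y would no longer match.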
Upvotes: 5