user308827

Reputation: 21961

Using a scikit-learn preprocessor to select a subset of rows in a pandas DataFrame

Is there a scikit-learn preprocessor I can use or implement to select a subset of rows from a pandas DataFrame? I would prefer a preprocessor because I want to build a pipeline with this as one of its steps.

Upvotes: 6

Views: 1091

Answers (1)

Simon Hawe

Reputation: 4529

You can use a FunctionTransformer to do that. A FunctionTransformer accepts any callable that exposes the same interface as a standard scikit-learn transform call, i.e. it takes the data X and returns the transformed data. In code:

import pandas as pd
from sklearn.preprocessing import FunctionTransformer

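class RowSelector:
    """Callable that keeps only the rows at the given integer positions."""

    def __init__(self, rows: list[int]):
        self._rows = rows

    def __call__(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
        # iloc selects by position, not by index label
        return X.iloc[self._rows, :]

# Wrap the callable so it can be used as a pipeline step
selector = FunctionTransformer(RowSelector(rows=[1, 3]))
df = pd.DataFrame({'a': range(4), 'b': range(4), 'c': range(4)})
selector.fit_transform(df)
# Returns:
#    a  b  c
# 1  1  1  1
# 3  3  3  3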

Note that I have used a callable object to hold some state, i.e. the rows to be selected. This is not necessary; the rows could be supplied in other ways, for example as an argument to a plain function.
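As one possible alternative to the callable class, here is a minimal sketch that uses a plain function and passes the rows through FunctionTransformer's kw_args parameter. The function name select_rows is made up for this illustration and is not part of scikit-learn.

```python
import pandas as pd
from sklearn.preprocessing import FunctionTransformer


def select_rows(X: pd.DataFrame, rows: list[int]) -> pd.DataFrame:
    # Hypothetical helper: keep only the rows at the given integer positions
    return X.iloc[rows, :]


# kw_args is forwarded to select_rows as keyword arguments on every transform
selector = FunctionTransformer(select_rows, kw_args={"rows": [1, 3]})
df = pd.DataFrame({"a": range(4), "b": range(4), "c": range(4)})
result = selector.fit_transform(df)
print(result)
```

This keeps the selected rows in the transformer's parameters rather than inside a custom object, which may be easier to inspect or change later.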

The nice thing is that it returns a DataFrame, so if you use it as the first step of your pipeline, you can also combine it with a subsequent ColumnTransformer (if needed, of course).
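A rough sketch of that combination could look like the following. The choice of StandardScaler on column 'a' and the step names are assumptions made only for illustration. Also keep in mind that pipeline steps transform X only, so if you fit a supervised estimator at the end, y would not be filtered to match the selected rows.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer, StandardScaler


def select_rows(X: pd.DataFrame, rows: list[int]) -> pd.DataFrame:
    # Hypothetical helper: keep only the rows at the given integer positions
    return X.iloc[rows, :]


pipeline = Pipeline([
    # Step 1: row selection, which returns a DataFrame
    ("rows", FunctionTransformer(select_rows, kw_args={"rows": [1, 3]})),
    # Step 2: column-wise processing by column name, which needs a DataFrame
    ("cols", ColumnTransformer(
        [("scale_a", StandardScaler(), ["a"])],  # illustrative choice
        remainder="passthrough",
    )),
])

df = pd.DataFrame({"a": range(4), "b": range(4), "c": range(4)})
out = pipeline.fit_transform(df)
print(out.shape)
```

Because the row selector hands a DataFrame to the next step, the ColumnTransformer can still address columns by name.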

Upvotes: 5
