jmich738

Reputation: 1675

Creating custom transformer with sklearn - missing required positional argument error

I'm trying to create a custom transformer that splits a column into multiple columns, and I also want to be able to specify the delimiter.

Here is the code for the transformer:

class StringSplitTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, cols = None):
        self.cols = cols
    def transform(self,df,delim):
        X = df.copy()
        for col in self.cols:
            X = pd.concat([X,X[col].str.split(delim,expand = True)], axis = 1)
        return X
    def fit(self, *_):
        return self

When I run fit() and transform() separately, it all works fine:

split_trans = StringSplitTransformer(cols = ['Cabin'])
split_trans.fit(df)
split_trans.transform(df, '/')

But when I run fit_transform(), it gives me an error:

split_trans.fit_transform(X_train, '/')

TypeError: transform() missing 1 required positional argument: 'delim'

If I remove the delim parameter from transform() and hard-code the delimiter inside the method instead, fit_transform() works. I don't understand why.

Upvotes: 1

Views: 549

Answers (1)

Sanjar Adilov

Reputation: 1099

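fit should accept at least two arguments: a positional X and an optional y=None. fit_transform(X, y=None, **fit_params) calls fit(X, y, **fit_params) and then transform(X), so when you call fit_transform(X_train, '/'), the '/' is bound to y and passed to fit, and transform is called with X only, which leaves delim missing. I would rather make delim an attribute: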

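import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class StringSplitTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, delim, cols=None):
        # The delimiter is stored on the instance instead of being passed to transform()
        self.delim = delim
        self.cols = cols

    def fit(self, df, y=None):
        # Nothing to learn; y is accepted only to match the expected fit signature
        return self

    def transform(self, df):
        X = df.copy()
        for col in self.cols:
            # Append the split parts as new columns next to the original ones
            X = pd.concat([X, X[col].str.split(self.delim, expand=True)],
                          axis=1)
        return X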
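
As a quick check, here is a minimal sketch of how the attribute-based version might be used with fit_transform. The sample DataFrame and its 'Cabin' values are made up for illustration, and fit_transform_sketch is a hypothetical, simplified stand-in for what TransformerMixin.fit_transform does (it is not scikit-learn's actual implementation); it is only there to show that a second positional argument reaches fit as y and never reaches transform.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class StringSplitTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, delim, cols=None):
        self.delim = delim
        self.cols = cols

    def fit(self, df, y=None):
        return self

    def transform(self, df):
        X = df.copy()
        for col in self.cols:
            X = pd.concat([X, X[col].str.split(self.delim, expand=True)],
                          axis=1)
        return X


def fit_transform_sketch(est, X, y=None, **fit_params):
    # Hypothetical simplification of TransformerMixin.fit_transform:
    # y and fit_params go to fit(); transform() only ever receives X.
    return est.fit(X, y, **fit_params).transform(X)


# Made-up sample data, assumed to resemble the 'Cabin' column in the question.
df = pd.DataFrame({"Cabin": ["B/0/P", "F/1/S"]})

split_trans = StringSplitTransformer(delim="/", cols=["Cabin"])

# The delimiter lives on the instance, so no extra positional argument is needed.
result = split_trans.fit_transform(df)
sketch_result = fit_transform_sketch(split_trans, df)

print(result)
```

Because delim is now a constructor parameter, it also shows up in get_params(), which is what lets the transformer be cloned and used inside a Pipeline or a grid search.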

Upvotes: 3
