Santi XGR

Reputation: 307

Obtain feature importance from a mixed effects random forest

I am an R user running Python 3.7 (64-bit) on Windows for the first time. I was trying to get permutation importance from a mixed effects random forest using PermutationImportance from the eli5 package. A dataset for reproducibility can be found here.

Fit:

merf = MERF(n_estimators= 500, max_iterations= 100)
np.random.seed(100)
merf.fit(X_train_merf, Z_train, clusters_train, y_train)

Feature importance:

perimp = PermutationImportance(merf, cv = None, refit = False, n_iter = 50).fit(X_train, Z_train, clusters_train, y_train)

The above code produces this error:

TypeError: fit() takes from 3 to 4 positional arguments but 5 were given

But fit() contains only 4 arguments...

Is it possible to obtain feature importance at all from merf objects?

Upvotes: 2

Views: 1046

Answers (2)

jypucca

Reputation: 121

I ran into the same issue, and from the source code I figured out that you can pull the feature importances from the random forest model object that MERF keeps internally.

merf.fit(X_train, Z_train, clusters_train, y_train)
# trained_fe_model is the fitted random forest stored inside the MERF object
feat_importance = merf.trained_fe_model.feature_importances_
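
If you specifically want permutation importance rather than the forest's built-in importances, it can be computed by hand, because it only needs a prediction function. Below is a minimal sketch in plain Python, not eli5's implementation. The names `permutation_importance` and `toy_predict` are made up for illustration, and mean squared error is my own choice of score.

```python
import random


def permutation_importance(predict, X, y, n_iter=10, seed=0):
    """Mean increase in MSE when each column of X is shuffled.

    predict: callable taking a list of rows and returning a list of predictions.
    X: list of rows (each a list of numbers); y: list of targets.
    """
    rng = random.Random(seed)

    def mse(rows):
        preds = predict(rows)
        return sum((p - t) ** 2 for p, t in zip(preds, y)) / len(y)

    baseline = mse(X)
    importances = []
    for j in range(len(X[0])):
        increases = []
        for _ in range(n_iter):
            column = [row[j] for row in X]
            rng.shuffle(column)
            # Copy the rows, replacing only column j with the shuffled values.
            shuffled = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            increases.append(mse(shuffled) - baseline)
        importances.append(sum(increases) / n_iter)
    return importances


# Hypothetical stand-in for a fitted model: it depends on column 0 only.
def toy_predict(rows):
    return [3.0 * row[0] for row in rows]


X_demo = [[float(i), float(i % 3)] for i in range(20)]
y_demo = toy_predict(X_demo)
scores = permutation_importance(toy_predict, X_demo, y_demo, n_iter=5, seed=1)
```

In the toy data the second column is ignored by the model, so its score comes out as zero while the first column's score is positive. For a MERF model, `predict` would presumably be a small wrapper that calls the model's predict method with the rows plus Z_train and clusters_train held fixed. I have not tested that against merf itself.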

Upvotes: 2

Jim Panse

Reputation: 625

I don't know about the merf or eli5 module, but I can tell you why that exception is happening.

If you look at the API documentation of the PermutationImportance class, you can see the following definition of the fit() method:

    fit(X, y, groups=None, **fit_params)

The two stars before the last parameter mean that it collects any extra keyword arguments. So this method accepts at most 3 positional arguments (X, y and groups; the error message also counts self, which is why it says "from 3 to 4") plus any number of keyword arguments. That also means you need to name anything you pass beyond the third argument. Inside the method these named arguments arrive as a dictionary, and the method needs to know how to handle them.

Example:

def my_fit(X, **fit_params):
    print(fit_params)

my_fit("positional argument", x=1, y=2, z=3)
# prints: {'x': 1, 'y': 2, 'z': 3}

I don't use eli5, so I can't tell you which keywords to use or whether it is possible to obtain feature importance from merf objects at all, but the error itself goes away if you give your last argument a name, like this:

perimp = PermutationImportance(merf, cv = None, refit = False, n_iter = 50).fit(X_train, Z_train, clusters_train, y_train=y_train)

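Note that in this call Z_train is bound to y and clusters_train to groups, while y_train ends up in fit_params. Hopefully the method knows what to do with a parameter named like this.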
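
To check how the arguments are bound in a call like that one, here is a small stand-in with the same fit() signature as quoted above. The class name `FakeImportance` is made up and this is not eli5's code. It only returns what each parameter received, and strings take the place of the real data.

```python
class FakeImportance:
    # Hypothetical stand-in that only mirrors the fit() signature quoted above.
    def fit(self, X, y, groups=None, **fit_params):
        return {"X": X, "y": y, "groups": groups, "fit_params": fit_params}


est = FakeImportance()

# Four positional arguments plus self make five, which reproduces the error.
try:
    est.fit("X_train", "Z_train", "clusters_train", "y_train")
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Naming the last argument avoids the TypeError; inspect what landed where.
bound = est.fit("X_train", "Z_train", "clusters_train", y_train="y_train")
print(error_message)
print(bound)
```

So the TypeError disappears, but the second positional value is treated as the target y, which is probably not what you want for a MERF model.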

Upvotes: 2
