Reputation: 23
I am trying to replicate the solution in Python pandas: how to run multiple univariate regression by group, but using sklearn's LinearRegression instead of statsmodels.
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
df = pd.DataFrame({
'y': np.random.randn(20),
'x1': np.random.randn(20),
'x2': np.random.randn(20),
'grp': ['a', 'b'] * 10})
def ols_res(x, y):
    return pd.Series(LinearRegression.fit(x,y).predict(x))
results = df.groupby('grp').apply(lambda x : x[['x1', 'x2']].apply(ols_res, y=x['y']))
print(results)
I get:
TypeError: ("fit() missing 1 required positional argument: 'y'", 'occurred at index x1')
The results should look like the output of the solution I linked, which is:
             x1        x2
grp
a   0 -0.102766 -0.205196
    1 -0.073282 -0.102290
    2  0.023832  0.033228
    3  0.059369 -0.017519
    4  0.003281 -0.077150
            ...       ...
b   5  0.072874 -0.002919
    6  0.180362  0.000502
    7  0.005274  0.050313
    8 -0.065506 -0.005163
    9  0.003419 -0.013829
Upvotes: 2
Views: 1025
Reputation: 19885
There are two minor problems with your code:
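You don't instantiate a LinearRegression object, so your code actually tries to call the unbound fit method of the LinearRegression class. Because the call goes through the class rather than an instance, your x fills the self slot and your y fills X, leaving nothing for y; that is the "missing 1 required positional argument: 'y'" error you see.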
Even if you fix this, the LinearRegression instance will be unable to perform fit and predict, because it expects a 2D array of shape (n_samples, n_features) and gets a 1D one: DataFrame.apply passes each column to ols_res as a 1D Series. Accordingly, you also need to reshape the array contained in each Series.
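Both problems can be reproduced in isolation on a single toy column. This is a minimal sketch on made-up data (`x` and `y` below are illustrative, not from the question). The exact exception for the unbound call may depend on the scikit-learn version: older releases raise the TypeError shown in the question, while newer ones wrap fit in a decorator that can fail first with an AttributeError, so the sketch catches both.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

x = np.arange(5, dtype=float)  # 1D, like one column handed over by DataFrame.apply
y = 2 * x + 1                  # made-up target: slope 2, intercept 1

# Problem 1: fit is called on the class, so x lands in the self slot and y in X
try:
    LinearRegression.fit(x, y)
    unbound_error = None
except (TypeError, AttributeError) as exc:  # which one depends on the sklearn version
    unbound_error = exc
print(type(unbound_error).__name__, unbound_error)

# Problem 2: a real instance, but X is still 1D
try:
    LinearRegression().fit(x, y)
    shape_error = None
except ValueError as exc:
    shape_error = exc
print(str(shape_error).splitlines()[0])

# Fix: instantiate, and reshape X to (n_samples, 1)
model = LinearRegression().fit(x.reshape(-1, 1), y)
print(model.coef_, model.intercept_)
```

Since `y` is exactly linear in `x` here, the fitted slope and intercept come back as 2 and 1.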
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
df = pd.DataFrame({
'y': np.random.randn(20),
'x1': np.random.randn(20),
'x2': np.random.randn(20),
'grp': ['a', 'b'] * 10})
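def ols_res(x, y):
    # x arrives as a 1D Series (one column); sklearn wants X of shape (n_samples, 1)
    x_2d = x.values.reshape(len(x), -1)
    # instantiate LinearRegression before calling fit, then return the fitted values
    return pd.Series(LinearRegression().fit(x_2d, y).predict(x_2d))

results = df.groupby('grp').apply(lambda g: g[['x1', 'x2']].apply(ols_res, y=g['y']))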
print(results)
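If the nested apply feels opaque, the same per-group, per-column regressions can also be written as an explicit loop. This is an alternative sketch, not a correction of the code above. It seeds the generator with `np.random.default_rng(0)` (an assumption added only so reruns are reproducible) and selects each predictor with double brackets, `g[[col]]`, which already yields a 2D DataFrame, so no reshape is needed.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)  # seeded only so reruns give the same numbers
df = pd.DataFrame({
    'y': rng.standard_normal(20),
    'x1': rng.standard_normal(20),
    'x2': rng.standard_normal(20),
    'grp': ['a', 'b'] * 10})

pieces = {}
for name, g in df.groupby('grp'):
    fitted = {}
    for col in ['x1', 'x2']:
        X = g[[col]]  # double brackets -> 2D DataFrame of shape (n, 1)
        fitted[col] = LinearRegression().fit(X, g['y']).predict(X)
    # fresh 0..n-1 index per group, like pd.Series(...) in ols_res
    pieces[name] = pd.DataFrame(fitted)

results = pd.concat(pieces)          # dict keys become the outer index level
results.index.names = ['grp', None]
print(results)
```

The result has the same layout as the apply-based version: an outer `grp` level and a 0..n-1 inner index within each group.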
Output:
             x1        x2
grp
a   0 -0.126680  0.137907
    1 -0.441300 -0.595972
    2 -0.285903 -0.385033
    3 -0.252434  0.560938
    4 -0.046632 -0.718514
    5 -0.267396 -0.693155
    6 -0.364425 -0.476643
    7 -0.221493 -0.779082
    8 -0.203781  0.722860
    9 -0.106912 -0.090262
b   0 -0.015384  0.092137
    1  0.478447  0.032881
    2  0.366102  0.059832
    3 -0.055907  0.055388
    4 -0.221876  0.013941
    5 -0.054299  0.048263
    6  0.043979  0.024594
    7 -0.307831  0.059972
    8 -0.226570 -0.024809
    9  0.394460  0.038921
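These numbers differ from the ones in the question simply because `np.random.randn` draws new data on every run, so only the layout can match. One setting can change the values even on identical data: statsmodels' array interface `sm.OLS` fits no intercept unless a constant column is added (for example with `sm.add_constant`), whereas `LinearRegression` fits one by default (`fit_intercept=True`). The sketch below, on made-up data, checks that both sklearn settings reproduce plain least squares computed with NumPy; which one corresponds to the linked solution depends on whether that solution included a constant.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(1)
x = rng.standard_normal(10)  # made-up single predictor
y = rng.standard_normal(10)
X = x.reshape(-1, 1)         # sklearn wants shape (n_samples, 1)

# With intercept (sklearn default): compare to a degree-1 least-squares polynomial
sk_fitted = LinearRegression().fit(X, y).predict(X)
slope, intercept = np.polyfit(x, y, deg=1)  # polyfit returns [slope, intercept]
np_fitted = slope * x + intercept

# Without intercept: the least-squares slope through the origin is sum(x*y) / sum(x*x)
sk_noint = LinearRegression(fit_intercept=False).fit(X, y).predict(X)
np_noint = x * (x @ y) / (x @ x)

print(np.allclose(sk_fitted, np_fitted))  # True
print(np.allclose(sk_noint, np_noint))    # True
```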
Upvotes: 1