puranjan

Reputation: 373

Scikit-learn pipeline returns list of zeroes

I don't understand why my pipeline produces this wrong output.

Pipeline code:

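from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

my_pipeline = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('std_scaler', StandardScaler())
])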

Real data:

real = [[0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669, 1.0, 422.0, 15.9, 389.96, 8.65]]

The pipeline output that I want:

want = [[-0.44228927, -0.4898311 , -1.37640684, -0.27288841, -0.34321545, 0.36524574, -0.33092752,  1.20235683, -1.0016859 ,  0.05733231, -1.21003475,  0.38110555, -0.57309194]]

But after running the code below:

getting = my_pipeline.fit_transform(real)

I am getting:

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

Upvotes: 1

Views: 147

Answers (1)

seralouk

Reputation: 33147

The problem

This is expected behavior given the shape of your data: real is a single sample (one row) with 13 features (columns), and the pipeline is fitted on that one row.

After the first step of the pipeline, i.e., the SimpleImputer, the output is a NumPy array with shape (1, 13): one row (sample) and 13 columns (features).

from sklearn.impute import SimpleImputer

si = SimpleImputer()
si_out = si.fit_transform(real)

si_out.shape
# (1, 13)

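The returned (1, 13) array is the problem here. StandardScaler standardizes each column independently: it subtracts the column mean and divides by the column standard deviation. With only one row, each column holds a single value, and that value is also the column's mean, so every centered value is 0. The standard deviation of each column is 0 as well; StandardScaler replaces a zero standard deviation with 1 to avoid dividing by zero, so the output stays all 0s.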

from sklearn.preprocessing import StandardScaler

sc = StandardScaler()
sc.fit_transform(si_out)

returns

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
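To check this explanation, the sketch below inspects the statistics the scaler learns from the single row (mean_, var_ and scale_ are standard StandardScaler attributes). The second part fits a scaler on a few made-up rows; the array named reference is purely hypothetical and is not data from the question. It only illustrates that a single row gets non-zero values when the mean and standard deviation are estimated from more than one sample.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# The values from the question: one sample (row) with 13 features (columns).
real = np.array([[0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669,
                  1.0, 422.0, 15.9, 389.96, 8.65]])

sc = StandardScaler().fit(real)

# Each column's mean is simply the row's own value ...
print(np.allclose(sc.mean_, real[0]))  # True
# ... each column's variance is 0, so StandardScaler uses a scale of 1 instead.
print(sc.var_)
print(sc.scale_)
# (x - mean_) / scale_ is therefore 0 everywhere.
print(sc.transform(real))

# Hypothetical reference data (made up for illustration): the question's row
# plus two scaled copies, so each column has more than one value.
reference = np.vstack([real, 0.5 * real, 2.0 * real])
sc_ref = StandardScaler().fit(reference)
# Transforming the single row with statistics learned from several rows gives
# non-zero values, except in the two columns that are 0 in every reference row.
print(sc_ref.transform(real))
```

The zeros therefore come from fitting and transforming the same single row, not from the pipeline itself.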

The solution

It seems that real actually holds 13 values of a single variable/feature. If so, reshape it into one column, i.e., shape (13, 1), before fitting.

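import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler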

real = np.array([[0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669, 1.0, 422.0, 15.9, 389.96, 8.65]]).reshape(-1,1)

my_pipeline = Pipeline(steps=[ 
    ('imputer', SimpleImputer(strategy='median')),
    ('std_scaler', StandardScaler())
])
my_pipeline.fit_transform(real)

array([[-0.48677709],
       [-0.4869504 ],
       [-0.47383804],
       [-0.4869504 ],
       [-0.48335664],
       [-0.44157747],
       [-0.07276633],
       [-0.44347217],
       [-0.48001264],
       [ 2.44078289],
       [-0.37664007],
       [ 2.21849716],
       [-0.4269388 ]])
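As a sanity check on this output, the sketch below recomputes it by hand (the names values, col, pipe and manual are illustrative). There are no missing values, so the imputer leaves the data unchanged and the pipeline reduces to z = (x - mean) / std over the 13 values, using the population standard deviation (ddof=0), which is what StandardScaler uses.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

values = np.array([0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669,
                   1.0, 422.0, 15.9, 389.96, 8.65])

# One column with 13 rows: the shape produced by reshape(-1, 1).
col = values.reshape(-1, 1)

pipe = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('std_scaler', StandardScaler()),
])
out = pipe.fit_transform(col)

# No NaNs, so the imputer is a pass-through; the scaler then computes
# (x - mean) / std with NumPy's default population std (ddof=0).
manual = (col - col.mean()) / col.std()
print(np.allclose(out, manual))  # True
```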

Upvotes: 1
