Johnny

Reputation: 869

Selecting specific features based on correlation values

I am using the Housing train.csv data from Kaggle to run a prediction.
https://www.kaggle.com/c/house-prices-advanced-regression-techniques/data?select=train.csv

I am trying to compute correlations and keep only the features whose correlation with SalePrice is between 0.5 and 0.9. I tried to use this function to filter them, but it only removes the features whose correlation is above 0.9. How would I update this function to keep only the specific features that I need to generate a correlation heat map?

data = train
corr = data.corr()
columns = np.full((corr.shape[0],), True, dtype=bool)
for i in range(corr.shape[0]):
    for j in range(i+1, corr.shape[0]):
        if corr.iloc[i,j] >= 0.9:
            if columns[j]:
                columns[j] = False
selected_columns = data.columns[columns]
data = data[selected_columns]

Upvotes: 0

Views: 540

Answers (1)

snehil

Reputation: 626

import pandas as pd

data = pd.read_csv('train.csv')
col = data.columns

# keep numeric columns only; object (string) columns cannot be correlated
c = [i for i in col if data[i].dtypes == 'int64' or data[i].dtypes == 'float64']
main_col = ['SalePrice']        # column against which every other feature is compared

# correlation of each numeric feature with SalePrice (SalePrice's own row is dropped)
corr_saleprice = data[c].corr().filter(main_col).drop(main_col)

c1 = (corr_saleprice['SalePrice'] >= 0.5) & (corr_saleprice['SalePrice'] <= 0.9)
c2 = (corr_saleprice['SalePrice'] >= -0.9) & (corr_saleprice['SalePrice'] <= -0.5)

req_index = list(corr_saleprice[c1 | c2].index)   # columns that meet either criterion

# req_index.append('SalePrice')   # uncomment to keep the SalePrice column in the final dataframe too

data = data[req_index]

data

Also, explicit for loops are slower than pandas' vectorized operations for this kind of filtering, so a direct implementation like the one above is preferable. I hope this is what you want!
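If you need to reuse the filter, the same idea can be wrapped in a small function. This is only a sketch: select_by_corr is a name I made up, the demo frame is synthetic data standing in for train.csv, and I am assuming that the sign of the correlation does not matter, so both conditions collapse into one check on the absolute value.

```python
import numpy as np
import pandas as pd


def select_by_corr(df, target, low=0.5, high=0.9):
    """Return the numeric columns whose |correlation with target| is in [low, high]."""
    numeric = df.select_dtypes(include="number")
    # Correlation of every numeric column with the target; the target itself is excluded
    corr = numeric.corr()[target].drop(target)
    strength = corr.abs()
    return corr[(strength >= low) & (strength <= high)].index.tolist()


# Synthetic stand-in for train.csv (made-up values, not the Kaggle data)
rng = np.random.default_rng(0)
x = rng.normal(size=500)
demo = pd.DataFrame({
    "SalePrice": x,
    "strong_pos": x + rng.normal(scale=0.8, size=500),   # moderately correlated
    "strong_neg": -x + rng.normal(scale=0.8, size=500),  # moderately anti-correlated
    "noise": rng.normal(size=500),                       # unrelated to the target
    "near_copy": x + rng.normal(scale=0.01, size=500),   # correlation above 0.9
    "label": ["a"] * 500,                                # object column, ignored
})

selected = select_by_corr(demo, "SalePrice")
print(selected)
```

On the real data you would call select_by_corr(data, 'SalePrice') and then index the dataframe with the returned list, optionally adding 'SalePrice' back in.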

To generate the heatmap, you can use the following code:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

a = data.corr()
mask = np.triu(np.ones_like(a, dtype=bool))   # hide the diagonal and upper triangle (np.bool was removed from NumPy)
plt.figure(figsize=(10, 10))
_ = sns.heatmap(a, cmap=sns.diverging_palette(250, 20, n=250), square=True, mask=mask, annot=True, center=0.5)
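To see what the mask does without plotting anything, here is a small sketch on a made-up 3x3 correlation matrix (the names A, B, C and the values are hypothetical). Cells where the mask is True are the ones the heatmap leaves blank, so only the lower triangle remains.

```python
import numpy as np
import pandas as pd

# Made-up correlation matrix, for illustration only
names = ["A", "B", "C"]
a = pd.DataFrame(
    [[1.0, 0.6, 0.7],
     [0.6, 1.0, 0.8],
     [0.7, 0.8, 1.0]],
    index=names, columns=names,
)

# True marks the cells to hide: the diagonal and everything above it
mask = np.triu(np.ones_like(a, dtype=bool))

# Keep only the cells that would still be drawn; hidden cells become NaN
visible = a.where(~mask)
print(visible)
```

Since a correlation matrix is symmetric, the lower triangle already contains every pair once, which is why hiding the rest loses no information.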

Upvotes: 1
