Avanish Mishra

Reputation: 185

Remove highly correlated columns from a pandas dataframe

I have a dataframe named data, whose correlation matrix I computed using

corr = data.corr()

If the correlation between two columns is greater than 0.75, I want to remove one of them from the dataframe data. I tried the following:

raw = corr[(corr.abs() > 0.75) & (corr.abs() < 1.0)]

but it did not help; I need the column numbers from raw for which the value is non-null. Basically, I am looking for a Python equivalent of the following R commands (which use the function findCorrelation).

hc = findCorrelation(corr, cutoff = 0.75)
hc = sort(hc)
data <- data[, -c(hc)]

If anyone can help me find a Python pandas equivalent of the R commands above, that would be helpful.

Upvotes: 8

Views: 20964

Answers (3)

I have used cottontail's answer a lot. Recently I ran into a problem in one of my projects and traced it back to this code, which I believe has some kind of randomness in it. I used it to remove collinearity among 400 features, and the output was slightly different every time. Once I replaced the code with my own (it works, but I'm not sure my solution is the best), the problem went away. It would be great if the randomness were fixed. To me it looks like, if A and B are correlated, either A or B stays.
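For readers who need identical output on every run, one simple order-based alternative is sketched below. It is not the replacement code mentioned above (which was not posted) and not caret's algorithm; the helper name drop_later_of_correlated_pairs and the demo matrix are made up for illustration. It assumes that dropping the later column of each highly correlated pair is acceptable.

```python
import numpy as np
import pandas as pd

def drop_later_of_correlated_pairs(corr, cutoff=0.75):
    """Hypothetical helper: names of columns to drop, decided only by column order."""
    acorr = corr.abs()
    # Keep the strict upper triangle so that each pair (i, j) is inspected once.
    upper_mask = np.triu(np.ones(acorr.shape, dtype=bool), k=1)
    upper = acorr.where(upper_mask)
    # A column is dropped if it is highly correlated with any earlier column.
    return [col for col in upper.columns if (upper[col] > cutoff).any()]

# Made-up correlation matrix: A and B are highly correlated, C is not.
corr_demo = pd.DataFrame(
    [[1.0, 0.9, 0.1],
     [0.9, 1.0, 0.2],
     [0.1, 0.2, 1.0]],
    index=list("ABC"), columns=list("ABC"))

to_drop = drop_later_of_correlated_pairs(corr_demo, cutoff=0.75)
print(to_drop)
```

Because the decision depends only on column order, the same input always gives the same result. The trade-off is that it can drop more columns than strictly necessary, since a column is compared against earlier columns even if those were themselves dropped.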

Upvotes: 0


Reputation: 23071

The answer by piRSquared works great but it removes all columns with correlation above the cutoff, which overdoes it compared to how findCorrelation behaves in R. Assuming these are features in a machine learning model, we need to drop columns just enough so that the pairwise correlation coefficients among the columns are less than some cutoff point (perhaps multicollinearity is a problem etc.). Dropping too many would potentially hurt whatever model that is built on this data. As Sergey Bushmanov mentions in a comment, between columns C and H, only one should be dropped.

Python implementation of R's caret::findCorrelation

R's caret::findCorrelation looks at the mean absolute correlation of each variable and removes the variable with the largest mean absolute correlation for each pair of columns. The following function (named findCorrelation) implements the very same logic.

Depending on the size of the correlation matrix, caret::findCorrelation calls one of two functions: the fully vectorized findCorrelation_fast or the loopy findCorrelation_exact (you can call either regardless of dataframe size by using the exact= argument appropriately). The function below does the very same.

The only behavior different from caret::findCorrelation is that it returns a list of column names whereas caret::findCorrelation returns the index of the columns. I believe it's more natural to return column names which we can pass to drop later on.

import numpy as np
import pandas as pd

def findCorrelation(corr, cutoff=0.9, exact=None):
    """
    This function is the Python implementation of the R function
    `findCorrelation()` from the caret package.

    Relies on numpy and pandas, so must have them pre-installed.

    It searches through a correlation matrix and returns a list of column names
    to remove to reduce pairwise correlations.

    For the documentation and source code of the R function, see the caret
    package.

    Parameters
    ----------
    corr: pandas dataframe.
        A correlation matrix as a pandas dataframe.
    cutoff: float, default: 0.9.
        A numeric value for the pairwise absolute correlation cutoff.
    exact: bool, default: None
        A boolean value that determines whether the average correlations are
        recomputed at each step. If None, the exact method is used when the
        matrix has fewer than 100 columns.

    Returns
    -------
    list of column names

    Example
    -------
    R1 = pd.DataFrame({
        'x1': [1.0, 0.86, 0.56, 0.32, 0.85],
        'x2': [0.86, 1.0, 0.01, 0.74, 0.32],
        'x3': [0.56, 0.01, 1.0, 0.65, 0.91],
        'x4': [0.32, 0.74, 0.65, 1.0, 0.36],
        'x5': [0.85, 0.32, 0.91, 0.36, 1.0]
    }, index=['x1', 'x2', 'x3', 'x4', 'x5'])

    findCorrelation(R1, cutoff=0.6, exact=False)  # ['x4', 'x5', 'x1', 'x3']
    findCorrelation(R1, cutoff=0.6, exact=True)   # ['x1', 'x5', 'x4']
    """

    def _findCorrelation_fast(corr, avg, cutoff):

        combsAboveCutoff = corr.where(lambda x: (np.tril(x)==0) & (x > cutoff)).stack().index

        rowsToCheck = combsAboveCutoff.get_level_values(0)
        colsToCheck = combsAboveCutoff.get_level_values(1)

        msk = avg[colsToCheck] > avg[rowsToCheck].values
        deletecol = pd.unique(np.r_[colsToCheck[msk], rowsToCheck[~msk]]).tolist()

        return deletecol

    def _findCorrelation_exact(corr, avg, cutoff):

        x = corr.loc[(*[avg.sort_values(ascending=False).index]*2,)]

        if (x.dtypes.values[:, None] == ['int64', 'int32', 'int16', 'int8']).any():
            x = x.astype(float)

        x.values[(*[np.arange(len(x))]*2,)] = np.nan

        deletecol = []
        for ix, i in enumerate(x.columns[:-1]):
            for j in x.columns[ix+1:]:
                if x.loc[i, j] > cutoff:
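                    if x[i].mean() > x[j].mean():
                        # i has the larger mean absolute correlation: drop it
                        deletecol.append(i)
                        x.loc[i] = x[i] = np.nan
                    else:
                        # otherwise drop j
                        deletecol.append(j)
                        x.loc[j] = x[j] = np.nan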
        return deletecol

    if not np.allclose(corr, corr.T) or any(corr.columns!=corr.index):
        raise ValueError("correlation matrix is not symmetric.")
    acorr = corr.abs()
    avg = acorr.mean()
    if exact or exact is None and corr.shape[1]<100:
        return _findCorrelation_exact(acorr, avg, cutoff)
    else:
        return _findCorrelation_fast(acorr, avg, cutoff)
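
To make the pairwise rule concrete, here is a small sketch that applies it by hand to a single pair from the docstring's example matrix. It only illustrates the "drop the column with the larger mean absolute correlation" comparison for one pair; the variable names pair and drop are introduced here for illustration.

```python
import pandas as pd

# Correlation matrix taken from the docstring example.
R1 = pd.DataFrame({
    'x1': [1.0, 0.86, 0.56, 0.32, 0.85],
    'x2': [0.86, 1.0, 0.01, 0.74, 0.32],
    'x3': [0.56, 0.01, 1.0, 0.65, 0.91],
    'x4': [0.32, 0.74, 0.65, 1.0, 0.36],
    'x5': [0.85, 0.32, 0.91, 0.36, 1.0]
}, index=['x1', 'x2', 'x3', 'x4', 'x5'])

# Mean absolute correlation of each column (diagonal included).
avg = R1.abs().mean()

# x1 and x2 have |r| = 0.86, which exceeds a cutoff of 0.6.
pair = ('x1', 'x2')

# Of the two, the column with the larger mean absolute correlation is dropped.
drop = max(pair, key=lambda c: avg[c])
print(avg.round(3).to_dict(), drop)
```

Here x1 has the larger mean (0.718 versus 0.586 for x2), so x1 is the one that would be removed for this pair.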

You can call findCorrelation to find the columns to drop and then call drop() on the dataframe to drop those columns (exactly how you would use this function in R).

Using piRSquared's setup (with the dataframe named df), the usage looks like this.

corr = df.corr()
hc = findCorrelation(corr, cutoff=0.5)
trimmed_df = df.drop(columns=hc)
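
To confirm that the trimmed dataframe meets the cutoff, one option is to check its largest off-diagonal absolute correlation. The sketch below uses a hypothetical helper, max_offdiag_abs_corr, and synthetic data generated only for this illustration; the column to drop is chosen by hand rather than by findCorrelation.

```python
import numpy as np
import pandas as pd

def max_offdiag_abs_corr(df):
    """Hypothetical helper: largest absolute pairwise correlation, diagonal ignored."""
    corr = df.corr().abs().to_numpy().copy()
    np.fill_diagonal(corr, 0.0)
    return corr.max()

# Synthetic data: "b" is nearly a copy of "a", while "c" is independent noise.
rng = np.random.default_rng(0)
a = rng.normal(size=200)
demo = pd.DataFrame({
    "a": a,
    "b": a + 0.01 * rng.normal(size=200),
    "c": rng.normal(size=200),
})

before = max_offdiag_abs_corr(demo)
after = max_offdiag_abs_corr(demo.drop(columns=["b"]))
print(before, after)
```

If the value computed after dropping is still above the cutoff, some correlated pair survived the trimming.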


Upvotes: 6


Reputation: 294228

Use np.eye to ignore the diagonal values and find all columns that have some value whose absolute value is greater than the threshold. Use the logical negation as a mask for the index and columns.

Your example

m = ~(corr.mask(np.eye(len(corr), dtype=bool)).abs() > 0.75).any()

raw = corr.loc[m, m]

Working example

data = pd.DataFrame(
    np.random.randint(10, size=(10, 10)),
    columns=list('ABCDEFGHIJ'))

   A  B  C  D  E  F  G  H  I  J
0  0  2  7  3  8  7  0  6  8  6
1  0  2  0  4  9  7  3  2  4  3
2  3  6  7  7  4  5  3  7  5  9
3  8  7  6  4  7  6  2  6  6  5
4  2  8  7  5  8  4  7  6  1  5
5  2  8  2  4  7  6  9  4  2  4
6  6  3  8  3  9  8  0  4  3  0
7  4  1  5  8  6  0  8  7  4  6
8  3  5  8  5  1  5  1  4  3  9
9  5  5  7  0  3  2  5  8  8  9

corr = data.corr()

      A     B     C     D     E     F     G     H     I     J
A  1.00  0.22  0.42 -0.12 -0.17 -0.16 -0.11  0.35  0.13 -0.06
B  0.22  1.00  0.10 -0.08 -0.18  0.07  0.33  0.12 -0.34  0.17
C  0.42  0.10  1.00 -0.08 -0.41 -0.12 -0.42  0.55  0.20  0.34
D -0.12 -0.08 -0.08  1.00 -0.05 -0.29  0.27  0.02 -0.45  0.11
E -0.17 -0.18 -0.41 -0.05  1.00  0.47  0.00 -0.38 -0.19 -0.86
F -0.16  0.07 -0.12 -0.29  0.47  1.00 -0.62 -0.67 -0.08 -0.54
G -0.11  0.33 -0.42  0.27  0.00 -0.62  1.00  0.22 -0.40  0.07
H  0.35  0.12  0.55  0.02 -0.38 -0.67  0.22  1.00  0.50  0.59
I  0.13 -0.34  0.20 -0.45 -0.19 -0.08 -0.40  0.50  1.00  0.40
J -0.06  0.17  0.34  0.11 -0.86 -0.54  0.07  0.59  0.40  1.00

m = ~(corr.mask(np.eye(len(corr), dtype=bool)).abs() > 0.5).any()

A     True
B     True
C    False
D     True
E    False
F    False
G    False
H    False
I     True
J    False
dtype: bool

raw = corr.loc[m, m]

      A     B     D     I
A  1.00  0.22 -0.12  0.13
B  0.22  1.00 -0.08 -0.34
D -0.12 -0.08  1.00 -0.45
I  0.13 -0.34 -0.45  1.00
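
The mask above filters the correlation matrix. To trim the original dataframe instead, the same boolean mask can be applied to its columns. The sketch below uses a tiny made-up dataframe (demo, with columns a, b, c) rather than the random data above.

```python
import numpy as np
import pandas as pd

# Made-up data: "b" is exactly 2 * "a"; "c" is only weakly related to either.
demo = pd.DataFrame({
    "a": [1, 2, 3, 4, 5],
    "b": [2, 4, 6, 8, 10],
    "c": [5, 1, 4, 2, 3],
})

corr_demo = demo.corr()

# Same masking idea as above: hide the diagonal, flag columns above the threshold.
m_demo = ~(corr_demo.mask(np.eye(len(corr_demo), dtype=bool)).abs() > 0.75).any()

# Apply the mask to the columns of the original data, not to the correlation matrix.
trimmed = demo.loc[:, m_demo]
print(list(trimmed.columns))
```

Note that both a and b are removed here, since each one is highly correlated with the other; only c remains.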

Upvotes: 15
