user12916766

Heatmap correlation using values of column?

Suppose I have the following data of repeat observations for US states with some value of interest:

US_State Value
Alabama  1
Alabama  10
Alabama  9
Michigan 8
Michigan 9
Michigan 2
...

How can I generate pairwise correlations for Value between all the US_State combinations? I've tried a few different things (pivot, groupby, and more), but I can't seem to wrap my head around the proper approach.

The ideal output would look like:

          Alabama   Michigan    ...
Alabama      1          0.5
Michigan     0.5        1
...

Upvotes: 1

Views: 2645

Answers (2)

ELinda

Reputation: 2821

A pandas DataFrame has a built-in correlation matrix method, corr(). You need to get your data into a DataFrame with one column per state and one row per observation (the constructor accepts NumPy arrays, a plain dict as shown below, etc.).

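from pandas import DataFrame

# one column per state, one row per observation
data = {'AL': [1, 10, 9],
        'MI': [8, 9, 2],
        'CO': [11, 5, 17]
        }

df = DataFrame(data)

# pairwise Pearson correlation between all columns
corrMatrix = df.corr()
print(corrMatrix)

# optional heatmap (needs seaborn and matplotlib)
import matplotlib.pyplot as plt
import seaborn as sn
sn.heatmap(corrMatrix, annot=True, cmap='coolwarm')
plt.show()  # display the figure when running as a plain script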

          AL        MI        CO
AL  1.000000 -0.285578 -0.101361
MI -0.285578  1.000000 -0.924473
CO -0.101361 -0.924473  1.000000
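The dict in this answer is typed in by hand. Starting from the long format in the question, one way to build the same one-column-per-state table is to number each state's observations with groupby().cumcount() and then pivot. This is a minimal sketch, assuming each state's rows are already in observation order and that the i-th value of one state should pair with the i-th value of every other state; the column name obs is made up for the example, and AL/MI/CO are spelled out as state names.

```python
import pandas as pd

# Long-format data as in the question: one row per (state, observation)
long_df = pd.DataFrame({
    "US_State": ["Alabama"] * 3 + ["Michigan"] * 3 + ["Colorado"] * 3,
    "Value": [1, 10, 9, 8, 9, 2, 11, 5, 17],
})

# Number each state's observations 0, 1, 2, ... in row order.
# Assumption: observation i of one state pairs with observation i of the others.
long_df["obs"] = long_df.groupby("US_State").cumcount()

# One row per observation number, one column per state
wide = long_df.pivot(index="obs", columns="US_State", values="Value")

# Pairwise Pearson correlation between states
corr_matrix = wide.corr()
print(corr_matrix)
```

This should reproduce the coefficients above, with the columns in alphabetical order (Alabama, Colorado, Michigan), and corr_matrix can be passed to sn.heatmap unchanged.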

Upvotes: 1

Richard Nemeth

Reputation: 1864

There is a way that uses pandas alone, but it only works under the assumption that each state in the input dataset has the same number of observations, and that observations at the same position belong together (the first Alabama value pairs with the first Michigan value, and so on). Otherwise the correlation coefficient does not really make sense and the results will be misleading.

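import pandas as pd

df = pd.DataFrame()
df['US_State'] = ["Alabama", "Alabama", "Alabama", "Michigan", "Michigan", "Michigan", "Oregon", "Oregon", "Oregon"]
df['Value'] = [1, 10, 9, 8, 9, 2, 6, 1, 2]

# Collect each state's values into one cell as a list, then transpose so the states become columns
collected = pd.DataFrame(df.groupby("US_State")['Value'].apply(list)).T
# Expand each list cell into one row per observation
wide = collected.apply(lambda x: pd.Series(*x), axis=0)
print(wide.corr())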

which results in

US_State   Alabama  Michigan    Oregon
US_State                              
Alabama   1.000000 -0.285578 -0.996078
Michigan -0.285578  1.000000  0.199667
Oregon   -0.996078  0.199667  1.000000

The code collects each state's values into a single cell as a list, transposes the DataFrame so that the states become columns, and then expands each list cell into rows, one per observation. After that you can call the standard corr() method of the DataFrame.
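To see why unequal observation counts make the results misleading, here is a small sketch with a hypothetical dataset in which Oregon is missing its third value. It pairs observations by position, as the code above does, but uses groupby().cumcount() and pivot() to build the wide table (the obs column name is made up). DataFrame.corr() excludes missing values pairwise, so each Oregon coefficient is computed from only two points, and a Pearson correlation of two points is always +1 or -1 (unless one side is constant).

```python
import pandas as pd

# Hypothetical data: Oregon has only two observations
df = pd.DataFrame({
    "US_State": ["Alabama"] * 3 + ["Michigan"] * 3 + ["Oregon"] * 2,
    "Value": [1, 10, 9, 8, 9, 2, 6, 1],
})

# Observations per state: a quick check of the equal-count assumption
counts = df.groupby("US_State").size()
print(counts)

# Pair observations by position within each state and widen
df["obs"] = df.groupby("US_State").cumcount()
wide = df.pivot(index="obs", columns="US_State", values="Value")  # Oregon is NaN at obs 2

# corr() drops NaN pairwise, so every Oregon coefficient rests on only 2 points
corr_matrix = wide.corr()
print(corr_matrix)
```

Checking that counts has a single unique value (counts.nunique() == 1) before calling corr() is one simple guard against this.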

Upvotes: 1
