Genius Kshitiz

Correlation Plot Color mismatch in Python using matplotlib

I am plotting the correlation matrix of my data in Python using Matplotlib. Highly correlated cells should be coloured dark red, but in my case they are coloured yellow. How can I fix this?

My correlation data is this:

[Screenshot of the correlation data, not reproduced here]

My code is like this:

import matplotlib.pyplot as plt

def plot_corr(df, size=11):
    """\
    Function plots a graphical correlation matrix for each pair of columns in the dataframe.

    Input:
        df: pandas DataFrame
        size: vertical and horizontal size of the plot

    Displays:
        matrix of correlation between columns. Blue-cyan-yellow-red-darkred => less to more correlated
                                               0 ------------------------> 1
                                               Expect a darkred line running from top left to bottom right
    """
    corr = df.corr()    # data frame correlation function
    fig, ax = plt.subplots(figsize=(size, size))
    ax.matshow(corr)    # color code the rectangles by correlation value
    plt.xticks(range(len(corr.columns)), corr.columns)   # draw x tick marks
    plt.yticks(range(len(corr.columns)), corr.columns)   # draw y tick marks
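Since the data is only available as a screenshot, here is a minimal, self-contained reproduction using hypothetical random data (the column names `a`, `b`, `c` and the headless Agg backend are assumptions, not part of the original question). It calls `ax.matshow` without a `cmap`, just like `plot_corr`, and reports which colormap Matplotlib picked.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical sample data: "b" closely tracks "a", "c" is independent noise
rng = np.random.default_rng(0)
a = rng.normal(size=200)
df = pd.DataFrame({
    "a": a,
    "b": a + 0.1 * rng.normal(size=200),
    "c": rng.normal(size=200),
})

corr = df.corr()
fig, ax = plt.subplots(figsize=(5, 5))
image = ax.matshow(corr)  # no cmap argument, so rcParams["image.cmap"] decides

# With stock settings on Matplotlib >= 2.0 this should print "viridis"
print(image.get_cmap().name)
```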

My output is like this:

[Screenshot of the resulting plot, not reproduced here]

Answer

gboffi

Matplotlib 2.0 changed the default colormap from "jet" to "viridis". The former maps the highest value to dark red, the latter to bright yellow, which is why your diagonal of 1s comes out yellow.

The change was not gratuitous: the new colormap has several advantages over the old one, such as being perceptually uniform and staying readable in grayscale and for colour-blind viewers (if you are interested in the reasons, see e.g. this GitHub issue).
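A colormap maps a normalized value in [0, 1] to an RGBA colour, so the difference between the two maps can be checked directly. This small sketch (it assumes Matplotlib 3.5 or later for the `matplotlib.colormaps` registry) samples both ends of each map: the top of "jet" should have a red component of about 0.5 and no green or blue (dark red), while the top of "viridis" is mostly red plus green (bright yellow).

```python
import matplotlib

# Look up both colormaps in the registry (matplotlib.colormaps needs Matplotlib >= 3.5)
jet = matplotlib.colormaps["jet"]
viridis = matplotlib.colormaps["viridis"]

# Calling a colormap with a float in [0, 1] returns an (r, g, b, a) tuple
for name, cmap in [("jet", jet), ("viridis", viridis)]:
    low = tuple(round(c, 2) for c in cmap(0.0))
    high = tuple(round(c, 2) for c in cmap(1.0))
    print(f"{name:8s} lowest -> {low}  highest -> {high}")

# The built-in (non-customized) default used when no cmap is passed
print("default image.cmap:", matplotlib.rcParamsDefault["image.cmap"])
```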

One possibility is to leave the defaults alone and update the part of the docstring that describes the colour range:

    """\
    ...
    Displays:
        matrix of correlation between columns. Purple-blue-green-yellow => less to more correlated
                                               0 ------------------------> 1
                                               Expect a bright yellow line running from top left to bottom right.
    """
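Rather than relying on the docstring alone, a colorbar puts the mapping on the figure itself, whatever colormap is active. This is a minimal sketch with hypothetical random data (the colorbar and the `vmin`/`vmax` limits are an addition, not part of the original answer). Fixing `vmin=-1, vmax=1` pins the scale to the full range of a correlation coefficient; otherwise `matshow` stretches the colormap between the smallest and largest values actually present.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical data standing in for the real data frame
rng = np.random.default_rng(1)
df = pd.DataFrame(rng.normal(size=(100, 4)), columns=list("wxyz"))
corr = df.corr()

fig, ax = plt.subplots(figsize=(5, 5))
# Pin the colour limits to the full range of a correlation coefficient
image = ax.matshow(corr, vmin=-1, vmax=1)
# The colorbar shows which colour means which value, whatever the colormap
fig.colorbar(image, ax=ax)
```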

Another is to name the colormap you want explicitly:

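import matplotlib.pyplot as plt

def plot_corr(df, size=11):
    ...
    fig, ax = plt.subplots(figsize=(size, size))
    ax.matshow(corr, cmap='jet')    # explicit colormap; plt.matshow would open a second figure
    ...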
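For reference, a complete runnable version of the explicit-colormap approach might look like this. It is a sketch: the extra `cmap` parameter and the returned `(fig, ax, image)` tuple are hypothetical additions so callers can pick another map or inspect the result, and the random data is only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_corr(df, size=11, cmap="jet"):
    """Plot the correlation matrix of df; cmap (hypothetical parameter) defaults to "jet"."""
    corr = df.corr()
    fig, ax = plt.subplots(figsize=(size, size))
    image = ax.matshow(corr, cmap=cmap)  # explicit colormap, independent of rcParams
    ax.set_xticks(range(len(corr.columns)))
    ax.set_xticklabels(corr.columns)
    ax.set_yticks(range(len(corr.columns)))
    ax.set_yticklabels(corr.columns)
    return fig, ax, image


# Hypothetical usage with random data
rng = np.random.default_rng(2)
df = pd.DataFrame(rng.normal(size=(50, 3)), columns=["p", "q", "r"])
fig, ax, image = plot_corr(df, size=4)
```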

A last possibility is to restore all of Matplotlib's previous defaults, either for the whole calling program

plt.style.use('classic')

or only inside the function, using a context manager

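    ...
    with plt.style.context('classic'):   # 'classic', not 'default': 'default' is the new viridis style
        fig, ax = plt.subplots(figsize=(size, size))
        ax.matshow(corr)
        ...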
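The context-manager variant only changes the settings temporarily. This minimal sketch (the identity matrix is a hypothetical stand-in for a correlation matrix) plots once inside `plt.style.context('classic')`, where `image.cmap` is "jet", and once after the block, where the previous settings are back in effect.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

data = np.eye(3)  # stand-in for a correlation matrix: 1 on the diagonal

with plt.style.context("classic"):
    # Inside the block the pre-2.0 defaults apply, including image.cmap = "jet"
    fig, ax = plt.subplots()
    inside = ax.matshow(data)

# Outside the block the previous rcParams are restored
fig2, ax2 = plt.subplots()
outside = ax2.matshow(data)

print(inside.get_cmap().name, outside.get_cmap().name)
```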
