Sean

Reputation: 3385

Plotting two heat maps side by side in Matplotlib

I have a function that plots the heat map for the correlation matrix of a DataFrame. The function looks like this:

import matplotlib.pyplot as plt

def corr_heatmap(data):
    columns = data.columns
    corr_matrix = data.corr()

    fig, ax = plt.subplots(figsize=(7, 7))
    mat = ax.matshow(corr_matrix, cmap='coolwarm')

    ax.set_xticks(range(len(columns)))
    ax.set_yticks(range(len(columns)))
    ax.set_xticklabels(columns)
    ax.set_yticklabels(columns)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')
    plt.colorbar(mat, fraction=0.045, pad=0.05)
    fig.tight_layout()
    plt.show()

    return mat

and when run with a DataFrame outputs something like this:

[Image: correlation heat map produced by corr_heatmap]

What I want to do is plot two of these heat maps side by side, but I'm having some trouble doing so. What I've done so far is attempt to assign each heat map to an AxesImage object and use subplots to plot them.

mat1 = corr_heatmap(corr_mat1)
mat2 = corr_heatmap(corr_mat2)

fig = plt.figure(figsize=(15, 15))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
ax1.plot(mat1)
ax2.plot(mat2)

but this gives me the following error:

TypeError: float() argument must be a string or a number, not 'AxesImage'

Would anybody happen to know a way that I could plot two heat map images side by side? Thank you.

EDIT

In case anyone's wondering what the final code for what I wanted to do would look like:

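def corr_heatmaps(data1, data2, method='pearson'):
    # Note: data1.name and data2.name are used as titles below. A DataFrame has
    # no `name` attribute by default, so the caller must set it beforehand.

    # Basic configuration.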
    fig, axes = plt.subplots(ncols=2, figsize=(12, 12))
    ax1, ax2 = axes
    corr_matrix1 = data1.corr(method=method)
    corr_matrix2 = data2.corr(method=method)
    columns1 = corr_matrix1.columns
    columns2 = corr_matrix2.columns

    # Heat maps.
    im1 = ax1.matshow(corr_matrix1, cmap='coolwarm')
    im2 = ax2.matshow(corr_matrix2, cmap='coolwarm')

    # Formatting for heat map 1.
    ax1.set_xticks(range(len(columns1)))
    ax1.set_yticks(range(len(columns1)))
    ax1.set_xticklabels(columns1)
    ax1.set_yticklabels(columns1)
    ax1.set_title(data1.name, y=-0.1)
    plt.setp(ax1.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')
    plt.colorbar(im1, fraction=0.045, pad=0.05, ax=ax1)

    # Formatting for heat map 2.
    ax2.set_xticks(range(len(columns2)))
    ax2.set_yticks(range(len(columns2)))
    ax2.set_xticklabels(columns2)
    ax2.set_yticklabels(columns2)
    ax2.set_title(data2.name, y=-0.1)
    plt.setp(ax2.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')
    plt.colorbar(im2, fraction=0.045, pad=0.05, ax=ax2)

    fig.tight_layout()

When run with two pandas DataFrames, this outputs something like the following image:

[Image: two correlation heat maps side by side]

Upvotes: 6

Views: 13603

Answers (2)

gmds

Reputation: 19885

What you need is the plt.subplots function. Instead of manually adding Axes objects to a Figure, you can initialise a Figure along with a number of Axes. Then, it is as simple as calling matshow on each Axes:

import numpy as np
import pandas as pd

from matplotlib import pyplot as plt

df = pd.DataFrame(np.random.rand(10, 10))

fig, axes = plt.subplots(ncols=2, figsize=(8, 4))

ax1, ax2 = axes

im1 = ax1.matshow(df.corr())
im2 = ax2.matshow(df.corr())

fig.colorbar(im1, ax=ax1)
fig.colorbar(im2, ax=ax2)

[Image: the two heat maps side by side, each with its own colorbar]

You can perform all the other formatting later.
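One way to carry the formatting over is to let the plotting function draw on an Axes passed in by the caller instead of creating its own figure. The following is only a minimal sketch under that assumption: `draw_corr_heatmap`, the `title` parameter, and the random sample data are hypothetical names chosen for illustration, not part of the question's code.

```python
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt


def draw_corr_heatmap(data, ax, title=None):
    """Draw the correlation heat map of `data` on an existing Axes.

    Hypothetical helper: it mirrors the formatting from the question,
    but does not create a figure or call plt.show() itself.
    """
    corr_matrix = data.corr()
    columns = corr_matrix.columns

    im = ax.matshow(corr_matrix, cmap='coolwarm')

    # Label every row and column with the corresponding column name.
    ax.set_xticks(range(len(columns)))
    ax.set_yticks(range(len(columns)))
    ax.set_xticklabels(columns)
    ax.set_yticklabels(columns)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')

    if title is not None:
        ax.set_title(title, y=-0.1)

    # Attach the colorbar to this Axes only, so each panel keeps its own.
    ax.figure.colorbar(im, ax=ax, fraction=0.045, pad=0.05)
    return im


# Sample data (assumed for the example): two small random DataFrames.
rng = np.random.default_rng(0)
df1 = pd.DataFrame(rng.random((10, 4)), columns=list('abcd'))
df2 = pd.DataFrame(rng.random((10, 4)), columns=list('wxyz'))

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 6))
im1 = draw_corr_heatmap(df1, ax1, title='first')
im2 = draw_corr_heatmap(df2, ax2, title='second')
fig.tight_layout()
```

Calling `plt.show()` or `fig.savefig(...)` afterwards displays or saves the figure. Because the helper never creates a figure, the same function can fill one panel, two panels, or a larger grid.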

Upvotes: 8

Ranjeet

Reputation: 382

Follow the example below: replace plot with matshow and customize the axes as needed.

import numpy as np 
import matplotlib.pyplot as plt 

def f(t): 
    return np.exp(-t) * np.cos(2*np.pi*t) 

t1 = np.arange(0.0, 3.0, 0.01) 

ax1 = plt.subplot(121) 
ax1.plot(t1, f(t1), 'k') 

ax2 = plt.subplot(122) 
ax2.plot(t1, f(t1), 'r') 
plt.show() 

Output:

[Image: the two line plots side by side]
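For reference, the same `plt.subplot(121)` / `plt.subplot(122)` layout with `matshow` in place of `plot` might look like the sketch below. The random correlation matrices built with `np.corrcoef` are stand-ins assumed for the example. In practice they would come from `DataFrame.corr()`.

```python
import numpy as np
import matplotlib.pyplot as plt

# Stand-in data (assumed): two 4x4 correlation matrices from random samples.
rng = np.random.default_rng(0)
corr1 = np.corrcoef(rng.random((4, 50)))
corr2 = np.corrcoef(rng.random((4, 50)))

fig = plt.figure(figsize=(8, 4))

# One row, two columns: left panel is 121, right panel is 122.
ax1 = plt.subplot(121)
ax2 = plt.subplot(122)

# matshow returns an AxesImage, which is what the colorbar needs.
im1 = ax1.matshow(corr1, cmap='coolwarm')
im2 = ax2.matshow(corr2, cmap='coolwarm')

# Pass ax=... so each colorbar is placed next to its own heat map.
plt.colorbar(im1, ax=ax1, fraction=0.045, pad=0.05)
plt.colorbar(im2, ax=ax2, fraction=0.045, pad=0.05)
```

The key difference from the failing attempt in the question is that `matshow` is called directly on each target Axes, rather than passing an existing `AxesImage` to `ax.plot`.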

Upvotes: -1
