muazfaiz

Reputation: 5031

Matplotlib stacked histogram using `scatter_matrix` on pandas dataframe

Currently I have the following code:

import matplotlib.pyplot as plt
import pandas as pd
from pandas.plotting import scatter_matrix

df= pd.read_csv(file, sep=',')
colors = list('r' if i==1 else 'b' for i in df['class']) # class is either 1 or 0
plt.figure()
scatter_matrix(df, color=colors )
plt.show()

This produces a scatter matrix with a plain histogram in each diagonal cell.

On the diagonal, instead of a plain histogram, I want a stacked histogram in which class 1 is red and class 0 is blue.

How can I do this?

Upvotes: 1

Views: 2476

Answers (2)

ImportanceOfBeingErnest

Reputation: 339725

Seaborn is probably the more convenient tool for a scatter-matrix style plot. However, I do not know of an easy way to put a stacked histogram on the diagonal of a seaborn PairGrid.
Since the question asks for matplotlib anyway, here is a solution using pandas and matplotlib. Unfortunately, it requires doing much of the work by hand. In the example below, seaborn is imported only to load some sample data, since the question did not provide any.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# seaborn import just needed to get some data
import seaborn as sns
df = sns.load_dataset("iris")


n_hist = 10
category = "species"
columns = ["sepal_length","sepal_width","petal_length","petal_width"]
mi = df[columns].values.min()
ma = df[columns].values.max()
hist_bins = np.linspace(mi, ma, n_hist)


fig, axes = plt.subplots(nrows=len(columns), ncols=len(columns), 
                         sharex="col")

for i,row in enumerate(columns):
    for j,col in enumerate(columns):
        ax= axes[i,j]
        if i == j:
            # diagonal: stacked histogram, one layer per category
            mi = df[col].min()
            ma = df[col].max()
            hist_bins = np.linspace(mi, ma, n_hist)
            # running top of the stack in each bin
            bottom = np.zeros(len(hist_bins) - 1)
            for n, cat in df.groupby(category):
                h, _ = np.histogram(cat[col].dropna(), bins=hist_bins)
                # each layer's height is its own count; it sits on the
                # layers below and starts at the left bin edge
                ax.bar(hist_bins[:-1], h, width=np.diff(hist_bins),
                       bottom=bottom, align="edge", label=n)
                bottom += h
        else:
            # off-diagonal: one scatter per category
            for (n,cat) in df.groupby(category):
                ax.scatter(cat[col], cat[row], s=5, label=n)
        ax.set_xlabel(col)
        ax.set_ylabel(row)
        #ax.legend()
plt.tight_layout()
plt.show()

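For the case in the question (a `class` column holding 0 or 1, drawn in blue and red), the same idea can be wrapped in a small helper. This is only a sketch: the function name `stacked_scatter_matrix`, its parameters, and the synthetic data are assumptions made for illustration, not part of pandas or matplotlib, and the random data merely stands in for the CSV that the question does not show.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def stacked_scatter_matrix(df, category, colors, n_bins=10):
    """Scatter matrix with per-class stacked histograms on the diagonal.

    Hypothetical helper: `colors` maps each class label to a color.
    """
    columns = [c for c in df.columns if c != category]
    fig, axes = plt.subplots(len(columns), len(columns), squeeze=False)
    for i, row in enumerate(columns):
        for j, col in enumerate(columns):
            ax = axes[i, j]
            if i == j:
                # n_bins + 1 edges give n_bins bins over this column's range
                edges = np.linspace(df[col].min(), df[col].max(), n_bins + 1)
                bottom = np.zeros(n_bins)  # running top of the stack
                for label, group in df.groupby(category):
                    counts, _ = np.histogram(group[col].dropna(), bins=edges)
                    ax.bar(edges[:-1], counts, width=np.diff(edges),
                           bottom=bottom, align="edge", color=colors[label])
                    bottom += counts
            else:
                for label, group in df.groupby(category):
                    ax.scatter(group[col], group[row], s=5,
                               color=colors[label])
            ax.set_xlabel(col)
            ax.set_ylabel(row)
    fig.tight_layout()
    return fig, axes


# Synthetic stand-in for the asker's CSV (an assumption, not their data)
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "a": rng.normal(size=200),
    "b": rng.normal(size=200),
    "class": rng.integers(0, 2, size=200),
})
fig, axes = stacked_scatter_matrix(df, "class", colors={1: "r", 0: "b"})
```

Calling `plt.show()` afterwards displays the figure. Passing the colors as a dictionary keyed by class label keeps the diagonal bars and the off-diagonal points consistent without relying on the default color cycle.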

Upvotes: 2

BENY

Reputation: 323376

Sample code using seaborn:

import seaborn as sns
sns.set(style="ticks")
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")


Upvotes: 1
