Reputation: 5031
Currently I have the following code:
import matplotlib.pyplot as plt
import pandas as pd
from pandas.plotting import scatter_matrix
df = pd.read_csv(file, sep=',')
colors = ['r' if i == 1 else 'b' for i in df['class']]  # class is either 1 or 0
plt.figure()
scatter_matrix(df, color=colors )
plt.show()
This produces a scatter matrix with a simple histogram in each diagonal panel.
Instead of the simple histogram, I want each diagonal panel to show a stacked histogram, with class '1' in red and class '0' in blue.
How can I do this?
Upvotes: 1
Views: 2476
Reputation: 339725
Seaborn is probably a good choice for scatter-matrix style plots. However, I do not know of an easy way to plot a stacked histogram on the diagonal of a PairGrid in seaborn.
Since the question asks for matplotlib anyway, the following is a solution using pandas and matplotlib. Unfortunately, it requires doing a lot of the work by hand. Here is an example (seaborn is imported only to get some data, since the question did not provide any).
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# seaborn import just needed to get some data
import seaborn as sns

df = sns.load_dataset("iris")

n_hist = 10
category = "species"
columns = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

fig, axes = plt.subplots(nrows=len(columns), ncols=len(columns),
                         sharex="col")

for i, row in enumerate(columns):
    for j, col in enumerate(columns):
        ax = axes[i, j]
        if i == j:
            # diagonal: stacked histogram of this column, one layer per category
            mi = df[col].values.min()
            ma = df[col].values.max()
            hist_bins = np.linspace(mi, ma, n_hist)

            def hist(x):
                h, e = np.histogram(x.dropna()[col], bins=hist_bins)
                return pd.Series(h, e[:-1])

            b = df[[col, category]].groupby(category).apply(hist).T
            # cumulative counts mark where each layer ends
            values = np.cumsum(b.values, axis=1)
            width = np.diff(hist_bins)[0]
            for k in range(len(b.columns)):
                # each layer has its own counts as height and starts
                # where the previous layer ended
                bottom = values[:, k-1] if k > 0 else None
                ax.bar(b.index, b.values[:, k], width=width,
                       bottom=bottom, align="edge")
        else:
            # off-diagonal: scatter plot, one color per category
            for n, cat in df.groupby(category):
                ax.scatter(cat[col], cat[row], s=5, label=n)
        ax.set_xlabel(col)
        ax.set_ylabel(row)
        #ax.legend()

plt.tight_layout()
plt.show()
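A shorter variant of the same idea, as a rough sketch: matplotlib's Axes.hist accepts a list of arrays together with stacked=True, so the per-category counts and bottoms do not have to be computed by hand. The helper name stacked_scatter_matrix and its parameters are made up for this illustration and are not part of pandas or matplotlib. It assumes the category column has only a few distinct values and that colors, if given, lists one color per category in sorted order.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def stacked_scatter_matrix(df, columns, category, n_bins=10, colors=None):
    """Scatter matrix with stacked per-category histograms on the diagonal.

    Hypothetical helper for illustration only.
    """
    # groupby sorts the keys, so for classes 0/1 the order is [0, 1]
    groups = list(df.groupby(category))
    labels = [label for label, _ in groups]
    n = len(columns)
    fig, axes = plt.subplots(nrows=n, ncols=n, sharex="col", squeeze=False)
    for i, row in enumerate(columns):
        for j, col in enumerate(columns):
            ax = axes[i, j]
            if i == j:
                # diagonal: one array per category, stacked by matplotlib
                data = [g[col].dropna().to_numpy() for _, g in groups]
                ax.hist(data, bins=n_bins, stacked=True,
                        color=colors, label=labels)
            else:
                # off-diagonal: one scatter call per category
                for k, (label, g) in enumerate(groups):
                    c = None if colors is None else colors[k]
                    ax.scatter(g[col], g[row], s=5, color=c, label=label)
            ax.set_xlabel(col)
            ax.set_ylabel(row)
    fig.tight_layout()
    return fig, axes
```

For the data in the question this might be called as stacked_scatter_matrix(df, feature_columns, "class", colors=["b", "r"]) followed by plt.show(), where feature_columns stands for whichever numeric columns should be plotted; class 0 then comes out blue and class 1 red.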
Upvotes: 2
Reputation: 323376
Sample code using seaborn:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="ticks")
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
plt.show()
Upvotes: 1