CiaranWelsh

Reputation: 7689

Can I colour a seaborn distplot by values in another variable?

I would like to colour a seaborn.distplot by another variable in my dataframe.

Here's an abstracted example:

import numpy
import pandas
import seaborn
import matplotlib.pyplot as plt
A = numpy.random.choice(100, size=1000)
B = numpy.random.choice(1000, size=1000)
df = pandas.DataFrame([A, B], index=['A', 'B']).transpose()
df = df.sort_values(by='B')
plt.figure()
seaborn.distplot(df['A'], bins=50)
plt.show()

Which produces a 50-bin histogram of df['A'] with a KDE curve (image not included).

Is it possible to now colour this plot based on the values in df['B']?


Edit:

To clarify, let's say that A is the distribution of people's ages and B is their weight. I'd like a gradient of colours such that the bars of the histogram are coloured (say) green if the older people are also heavy. Note that I'm not expecting a nice spectrum: the 'interesting' data may well appear in the middle of the plot. For me, the interesting data are values of A that have a low B.

I hope that clears things up.

Upvotes: 1

Views: 3759

Answers (1)

ImportanceOfBeingErnest

Reputation: 339705

I would suggest keeping data aggregation and visualization separate. This lets you split the problem into smaller pieces, each of which is easier to solve.

In this case, the idea would be to create a table like this from the input data:

              0     weight  density
age                                
(0, 10]   140.0  54.388877   0.0140
(10, 20]  269.0  71.422041   0.0269
(20, 30]  273.0  78.842196   0.0273
(30, 40]  188.0  79.433658   0.0188
(40, 50]   92.0  76.108056   0.0092
(50, 60]   28.0  69.800159   0.0028
(60, 70]    7.0  61.524235   0.0007
(70, 80]    3.0  52.942435   0.0003
(80, 90]    NaN        NaN      NaN

where the rows are the binned ages and the columns are the number of people (column 0), their mean weight and the density.
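If you prefer plain NumPy for the aggregation step, the same three columns can be computed with np.histogram: once for the counts, and once with weights= to get the per-bin sum of weights, whose ratio to the counts is the mean weight. This is a minimal sketch assuming the same 10-year bins; the helper name aggregate_by_bin is hypothetical. Note that np.histogram uses half-open bins [a, b) (the last one closed), whereas pd.cut uses (a, b] by default, so values lying exactly on a bin edge can land in a different bin.

```python
import numpy as np

def aggregate_by_bin(age, weight, bins):
    """Hypothetical helper: per-bin count, mean weight and density."""
    counts, _ = np.histogram(age, bins=bins)
    # sum of weights per bin, using the same bin edges as the counts
    weight_sums, _ = np.histogram(age, bins=bins, weights=weight)
    # mean weight per bin; empty bins become NaN instead of dividing by zero
    with np.errstate(invalid="ignore", divide="ignore"):
        mean_weight = np.where(counts > 0, weight_sums / counts, np.nan)
    # density = count / (total count * bin width), so sum(density * width) == 1
    density = counts / np.sum(counts * np.diff(bins))
    return counts, mean_weight, density

# same synthetic data as in the answer's code
rng = np.random.default_rng(46)
age = rng.rayleigh(20, size=1000)
weight = 80 * np.sin(np.sqrt((age + 1) / 20.0 * np.pi / 2.0))
bins = np.arange(0, 100, 10)
counts, mean_weight, density = aggregate_by_bin(age, weight, bins)
```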

Such a table can then be easily plotted.

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np; np.random.seed(46)
import seaborn as sns

# create data
a = np.random.rayleigh(20, size=1000)
b = 80*np.sin(np.sqrt((a+1)/20.*np.pi/2.))
df = pd.DataFrame({"age" : a,  "weight" : b})

# calculate the number of people and their mean weight per age bin
bins = np.arange(0,100,10)
# observed=False keeps empty age bins, so there is exactly one row per bar
groups = df.groupby(pd.cut(df.age, bins), observed=False)
# column 0 holds the count per bin, column "weight" the mean weight
df6 = pd.DataFrame({0: groups.size(), "weight": groups["weight"].mean()})

# density = count / (total count * bin width), so the bars integrate to 1 like the KDE
df6["density"] = df6[0]/np.sum(df6[0].fillna(0).values*np.diff(bins))

# prepare colors
norm=plt.Normalize(np.nanmin(df6["weight"].values), 
                   np.nanmax(df6["weight"].values))
colors = plt.cm.plasma(norm(df6["weight"].fillna(0).values))

# create figure and axes
fig, ax = plt.subplots()
# bar plot
ax.bar(bins[:-1],df6.fillna(0)["density"], width=10, color=colors, align="edge")
# KDE plot
sns.kdeplot(df["age"], ax=ax, color="k", lw=2)

#create colorbar
sm = plt.cm.ScalarMappable(cmap="plasma", norm=norm)
sm.set_array([])
fig.colorbar(sm, ax=ax, label="weight")

#annotate axes
ax.set_ylabel("density")
ax.set_xlabel("age")
plt.show()

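(Resulting figure: the age density drawn as bars coloured by mean weight with the plasma colormap, a black KDE curve of age, and a colorbar labelled "weight"; image not included.)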
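To stay closer to the question's original plot, you can also draw the histogram first and recolour its bars afterwards: ax.hist returns the bar patches, and each patch can be given a colour from the per-bin mean of the second variable. This is a minimal sketch using matplotlib only, not seaborn (distplot has been deprecated since seaborn 0.11). The helper colour_bars_by_mean is hypothetical, and the reversed colormap plasma_r is an assumed choice that puts low means of B, the interesting ones in the question, at the yellow end of the scale. Seaborn's histplot also draws its bars as patches on the axes, so a similar loop over ax.patches should work there, but that is an assumption not exercised here.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

def colour_bars_by_mean(ax, patches, x, z, bins, cmap_name="plasma_r"):
    """Hypothetical helper: colour each histogram bar by the mean of z in that bin."""
    counts, _ = np.histogram(x, bins=bins)
    sums, _ = np.histogram(x, bins=bins, weights=z)
    with np.errstate(invalid="ignore", divide="ignore"):
        means = np.where(counts > 0, sums / counts, np.nan)
    norm = plt.Normalize(np.nanmin(means), np.nanmax(means))
    cmap = plt.get_cmap(cmap_name)
    for patch, m in zip(patches, means):
        if not np.isnan(m):  # empty bins have zero height; leave their colour alone
            patch.set_facecolor(cmap(norm(m)))
    # colorbar mapping the bar colours back to values of z
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    ax.figure.colorbar(sm, ax=ax, label="mean of B")
    return means

# data shaped like the question's A and B
rng = np.random.default_rng(0)
A = rng.choice(100, size=1000)
B = rng.choice(1000, size=1000)
bins = np.linspace(0, 100, 51)  # 50 bins, as in the question

fig, ax = plt.subplots()
_, _, patches = ax.hist(A, bins=bins, density=True)
means = colour_bars_by_mean(ax, patches, A, B, bins)
```

Passing the same bins array to both ax.hist and np.histogram is what guarantees that the i-th patch and the i-th mean refer to the same bin.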

Upvotes: 4
