Reputation: 11
I am struggling with this function.
What I need this function to do is plot X number of seaborn distplots (each in a separate figure), one for every unique value in a column.
In my example below I am using the iris dataset converted to a dataframe with three additional columns added: city, color, period
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
from matplotlib import rcParams
from sklearn.datasets import load_iris
iris=load_iris()
df=pd.DataFrame(data= np.c_[iris['data'], iris['target']],columns= iris['feature_names'] + ['target'])
df['city']=np.random.choice(['New York','Paris','London'],size=len(df))
df['period']=np.random.choice(['before','after'],size=len(df))
df['color']=np.random.choice(['red','black','blue'],size=len(df))
unique_vals = df['period'].unique()
targets = [df.loc[df['period'] == val] for val in unique_vals]
for target in targets:
    sns.distplot(target[[r'petal width (cm)']], hist=False, label='shouldbedynamic')
    sns.distplot(target[[r'sepal width (cm)']], hist=False, label='shouldbedynamic')
plt.legend()
plt.show()
So far this code is able to plot the two measures I defined, split by X variable (in this case period).
Let's say now I want to see the exact same output (same measures, plotted split by period), but with a new plot/figure generated for every value in the city column. I can do this manually of course by filtering one value at a time, but if city has 50 unique values I would need a function to iterate through them and plot 50 separate distributions instead.
In addition to this I have two smaller questions:
label option to dynamically generate a legend showing what the colored lines represent?
EDIT: Just want to make clear that what I want is to be able to plot the distributions in separate plots (not in the same one). So if city has 50 distinct values, the same figure in the picture would be generated for each city's data: New York, Paris, London, etc.
Upvotes: 1
Views: 3039
Reputation: 193
Do you mean you want 50 separate plots (one for each city, and split by period)? Or do you want 50 distributions in the same plot (one for each city, not split by period)?
Getting dynamic labels is easy: just use groupby rather than unique
for period, group in df.groupby('period'):
    sns.distplot(group[[r'petal width (cm)']], hist=False, label=f'petal: {period}')
    sns.distplot(group[[r'sepal width (cm)']], hist=False, label=f'sepal: {period}')
plt.legend()
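To see why groupby makes the labels dynamic, here is a minimal sketch on a toy dataframe (the `toy` frame and its `width` column are made up for illustration, not part of the iris data). Each iteration yields the group key together with the matching rows, so the key can go straight into the label:

```python
import pandas as pd

# Hypothetical toy data, only to show what groupby yields.
toy = pd.DataFrame({
    "period": ["before", "after", "before", "after"],
    "width": [0.2, 1.3, 0.4, 1.5],
})

# Iterating a groupby gives (key, sub-dataframe) pairs, keys in sorted order.
groups = {period: group for period, group in toy.groupby("period")}

# The key is available for building a label, no manual filtering needed.
labels = [f"width: {period}" for period in groups]
print(labels)
```

With `unique` you would have to filter the dataframe yourself and carry the value along separately; groupby hands you both at once.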
You can also set the color parameter of sns.distplot to select whatever colours you want, but with 50 plots you might want to look at colormaps.
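As a rough sketch of the colormap idea, you could sample evenly spaced colours from a matplotlib colormap and map one to each label. The helper name `colours_for` and the choice of `viridis` are my own assumptions here, not anything seaborn requires:

```python
import matplotlib.pyplot as plt
import numpy as np


def colours_for(labels, cmap_name="viridis"):
    """Map each label to an evenly spaced RGBA colour from a colormap.

    Hypothetical helper for illustration; any matplotlib colormap name works.
    """
    cmap = plt.get_cmap(cmap_name)
    positions = np.linspace(0, 1, len(labels))
    return {label: cmap(pos) for label, pos in zip(labels, positions)}


# One colour per city; with 50 cities you would pass all 50 names instead.
palette = colours_for(["London", "New York", "Paris"])
```

Each value in `palette` is an RGBA tuple, which can be passed as the color argument when plotting the matching group.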
EDIT:
Now that it's clearer what you want, you could try something like this:
def plot_city(city_name, data):
    """Generate a plot for one city."""
    # one colour per measure
    measures = {
        'petal width (cm)': 'tab:orange',
        'sepal width (cm)': 'tab:blue',
    }
    # one line style per period
    line_styles = {
        'before': '--',
        'after': '-',
    }
    fig, ax = plt.subplots(figsize=(12, 9))
    for measure, colour in measures.items():
        for period, group in data.groupby('period'):
            sns.distplot(
                ax=ax,
                a=group[measure],
                hist=False,
                # label each line with its own measure, not always 'petal'
                label=f'{measure}: {period}',
                color=colour,
                kde_kws={'linestyle': line_styles[period]}
            )
    ax.set_title(city_name, fontsize=24)
    ax.set_xlabel('width (cm)', fontsize=18)
    ax.legend(fontsize=18)
    return fig

# one figure (and one saved PNG) per city
for city_name, data in df.groupby('city'):
    fig = plot_city(city_name, data)
    fig.savefig(f'./{city_name}.png', bbox_inches='tight')
plt.show()
Upvotes: 1