Kris

Reputation: 11

Creating function to plot multiple distribution plots for every unique value in column

I am struggling with this function. What I need it to do is plot a set of seaborn distplots, in a separate figure, for every unique value in a column. In my example below I am using the iris dataset converted to a dataframe, with three additional columns added: city, color, period.

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
from matplotlib import rcParams
from sklearn.datasets import load_iris

iris=load_iris()
df=pd.DataFrame(data= np.c_[iris['data'], iris['target']],columns= iris['feature_names'] + ['target'])

df['city']=np.random.choice(['New York','Paris','London'],size=len(df))
df['period']=np.random.choice(['before','after'],size=len(df))
df['color']=np.random.choice(['red','black','blue'],size=len(df))

unique_vals = df['period'].unique()
targets = [df.loc[df['period'] == val] for val in unique_vals]
for target in targets:
    sns.distplot(target[[r'petal width (cm)']], hist=False,label='shouldbedynamic')
    sns.distplot(target[[r'sepal width (cm)']], hist=False,label='shouldbedynamic')
    plt.legend()

plt.show()

Snippet of the above output

So far this code is able to plot two measures I defined split by X variable (in this case period). Let's say now I want to see the exact same output (same measures and split by period plotted) but for every value in the city column a new plot/figure would be generated. I can do this manually of course by filtering one value at a time, but in the case that city has 50 unique values I would need a function to iterate through and plot 50 separate distributions instead.

In addition to this I have two smaller questions:

EDIT: Just to make clear, what I want is to plot the distributions in separate plots (not in the same one). So if city has 50 distinct values, the same figure as in the picture would be generated for each city's data: New York, Paris, London, etc.

Upvotes: 1

Views: 3039

Answers (1)

VersBersch

Reputation: 193

Do you mean you want 50 separate plots (one for each city, each split by period)? Or do you want 50 distributions in the same plot (one for each city, not split by period)?

Getting dynamic labels is easy: just use groupby rather than unique.

for period, group in df.groupby('period'):
    sns.distplot(group[[r'petal width (cm)']], hist=False, label=f'petal: {period}')
    sns.distplot(group[[r'sepal width (cm)']], hist=False, label=f'sepal: {period}')

plt.legend() 

You can also set the color parameter of sns.distplot to select whatever colours you want, but with 50 plots you might want to look at colormaps
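As a rough sketch of the colormap idea, the snippet below samples evenly spaced colours from a named matplotlib colormap. The helper name `sample_colours` and the choice of `viridis` are just assumptions for illustration, not something from the question.

```python
import numpy as np
import matplotlib.pyplot as plt


def sample_colours(n, cmap_name="viridis"):
    """Return n evenly spaced RGBA tuples from a matplotlib colormap.

    sample_colours is a hypothetical helper name; any colormap name
    known to matplotlib can be passed as cmap_name.
    """
    cmap = plt.get_cmap(cmap_name)
    # A colormap maps a float in [0, 1] to an (r, g, b, a) tuple.
    return [cmap(x) for x in np.linspace(0, 1, n)]


# One colour per city, for example when there are 50 cities.
colours = sample_colours(50)
```

Each entry can then be passed as the color argument of a plotting call, for example `color=colours[i]` for the i-th city.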

EDIT:

Now that it's a bit clearer what you want, you could try something like this:

def plot_city(city_name, data):
    """ generate plot for one city """

    measures = {
        'petal width (cm)': 'tab:orange',
        'sepal width (cm)': 'tab:blue',
    }

    line_styles = {
        'before': '--',
        'after': '-',
    }

    fig, ax = plt.subplots(figsize=(12, 9))

    for measure, colour in measures.items():
        for period, group in data.groupby('period'):
            sns.distplot(
                ax=ax,
                a=group[measure], 
                hist=False, 
                label=f'{measure}: {period}',
                color=colour,
                kde_kws={'linestyle':line_styles[period]}
            )

    ax.set_title(city_name, fontsize=24)
    ax.set_xlabel('width (cm)', fontsize=18)
    plt.legend(fontsize=18) 

    return fig


for city_name, data in df.groupby('city'):
    fig = plot_city(city_name, data)
    fig.savefig(f'./{city_name}.png', bbox_inches='tight')
    plt.show()

Upvotes: 1
