chico0913

Reputation: 647

How to draw multiple seaborn `distplot`s in a single window?

I am trying to draw multiple seaborn distplots in a single window. I know how to generate a density plot for a single list of data, as shown by the make_density function in my code below. However, I am not sure how to draw several distplots in one window. Suppose my list stat_list contains 6 lists as its elements, and I want to draw one distplot for each of these 6 lists. How can I draw the 6 distplots in the same window, with 3 plots per row (so that my output has 2 rows of 3 plots)?

Thank you,


import matplotlib.pyplot as plt
import seaborn as sns

# function to plot the histogram for a single list.
def make_density(stat_list, color, x_label, y_label):
    
    # Plot formatting
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    # Draw the histogram and fit a density plot.
    sns.distplot(stat_list, hist = True, kde = True,
                 kde_kws = {'linewidth': 2}, color=color)
    
    # get the y-coordinates of the points of the density curve.
    dens_list = sns.distplot(stat_list, hist = False, kde = False,
             kde_kws = {'linewidth': 2}, color = color).get_lines()[0].get_data()[1].tolist()
        
    # find the maximum y-coordinates of the density curve.            
    max_dens_index = dens_list.index(max(dens_list))
    
    # find the mode of the density plot.
    mode_x = sns.distplot(stat_list, hist = False, kde = False,
             kde_kws = {'linewidth': 2}, color = color).get_lines()[0].get_data()[0].tolist()[max_dens_index]
    
    # draw a vertical line at the mode of the histogram.
    plt.axvline(mode_x, color='blue', linestyle='dashed', linewidth=1.5)
    plt.text(mode_x * 1.05, 0.16, 'Mode: {:.4f}'.format(mode_x))

# `stat_list` is a list of 6 lists.
# I want to draw the histogram and density plot of
# each of these 6 lists in a single window,
# with 3 plots per row,
# so in my example there would be 2 rows of 3 columns of plots (2 x 3 = 6).
stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4], [0.9, 0.8, 0.7, 0.6, 0.5],
             [0.2, 0.6, 0.75, 0.87, 0.91], [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]

Upvotes: 1

Views: 5451

Answers (2)

Paul H

Reputation: 68116

I would use seaborn's FacetGrid class for this.

Simple version:

import seaborn
seaborn.set(style='ticks', context='paper')

axgrid = (
    seaborn.load_dataset('titanic')
        .pipe(seaborn.FacetGrid, hue='deck', col='deck', col_wrap=3, sharey=False)
        .map(seaborn.distplot, 'fare')
)

Or with some modifications to your function:

from matplotlib import pyplot
import seaborn
seaborn.set(style='ticks', context='paper')


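# function to plot the histogram for a single list.
# FacetGrid.map() passes `color` (and `label`, since `hue` is set) as keyword arguments,
# so the function has to accept them even though `label` is not used.
def make_density(stat, color=None, x_label=None, y_label=None, ax=None, label=None):

    if ax is None:
        ax = pyplot.gca()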
    # Draw the histogram and fit a density plot.
    seaborn.distplot(stat, hist=True, kde=True,
                     kde_kws={'linewidth': 2}, color=color, ax=ax)

    # get the y-coordinates of the points of the density curve.
    dens_list = ax.get_lines()[0].get_data()[1]

    # find the maximum y-coordinates of the density curve.
    max_dens_index = dens_list.argmax()

    # find the mode of the density plot.
    mode_x = ax.get_lines()[0].get_data()[0][max_dens_index]

    # draw a vertical line at the mode of the histogram.
    ax.axvline(mode_x, color=color, linestyle='dashed', linewidth=1.5)
    ymax = ax.get_ylim()[1]
    ax.text(mode_x * 1.1, ymax * 0.16, 'Mode: {:.4f}'.format(mode_x))

    # Plot formatting
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)


axgrid = (
    seaborn.load_dataset('titanic')
        .pipe(seaborn.FacetGrid, hue='deck', col='deck', col_wrap=3, sharey=False)
        .map(make_density, 'fare')
)


Upvotes: 3

JohanC

Reputation: 80279

You can create a grid of subplots with fig, axes = plt.subplots(...). Then pass each 'ax' of the returned 'axes' as the ax= parameter of sns.distplot(). Note that you also need that same ax to set the labels (ax.set_xlabel()), because plt.xlabel() only changes the current subplot.
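A minimal sketch of that pitfall in plain matplotlib (the Agg backend is an assumption, only so it runs without a display): right after plt.subplots(), the current Axes is the last subplot created, so plt.xlabel() labels only that one, while ax.set_ylabel() targets each subplot explicitly.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed here so the sketch runs anywhere
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=1, ncols=3)

# pyplot-style call: acts only on the "current" Axes, i.e. the last subplot created
plt.xlabel("pyplot label")

# object-oriented calls: label every subplot explicitly through its own Axes
for i, ax in enumerate(axes):
    ax.set_ylabel(f"y label {i}")

labels = [ax.get_xlabel() for ax in axes]
print(labels)  # ['', '', 'pyplot label']
plt.close(fig)
```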

Calling sns.distplot three times is not recommended: every call draws onto the same ax again, adding more and more artists to it. Also note that you can use numpy methods such as argmax() to find the maximum efficiently, without converting to Python lists (which are quite slow when there is a lot of data).
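If only the mode is needed, it can also be computed without reading the curve back from the plot. The sketch below is a swapped-in technique, not what distplot does internally: it evaluates a Gaussian KDE directly with NumPy, using Scott's rule for the bandwidth (seaborn's own bandwidth choice may differ by version), and then picks the peak with argmax(). The helper name kde_mode and its gridsize/cut parameters are hypothetical.

```python
import numpy as np


def kde_mode(data, gridsize=200, cut=3):
    """Hypothetical helper: x position of the peak of a Gaussian KDE of `data`.

    Assumes at least two distinct values (otherwise the bandwidth is zero).
    """
    x = np.asarray(data, dtype=float)
    n = x.size
    # Scott's rule of thumb: sample std (ddof=1) times n ** (-1/5)
    bw = x.std(ddof=1) * n ** (-1 / 5)
    # evaluation grid extending `cut` bandwidths beyond the data range
    grid = np.linspace(x.min() - cut * bw, x.max() + cut * bw, gridsize)
    # sum of Gaussian kernels centred on each observation, fully vectorized
    z = (grid[:, None] - x[None, :]) / bw
    density = np.exp(-0.5 * z ** 2).sum(axis=1) / (n * bw * np.sqrt(2 * np.pi))
    # argmax works directly on the NumPy array, no conversion to a list needed
    return grid[density.argmax()]


print(round(float(kde_mode([0.3, 0.5, 0.7, 0.3, 0.5])), 2))
```

The returned value could then be passed to ax.axvline() in place of the mode scraped from ax.get_lines(); it will be close to, but not necessarily identical to, the peak of the curve seaborn draws.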

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# function to plot the histogram for a single list.
def make_density(stat, color, x_label, y_label, ax):
    # Draw the histogram and fit a density plot.
    sns.distplot(stat, hist=True, kde=True,
                 kde_kws={'linewidth': 2}, color=color, ax=ax)

    # get the y-coordinates of the points of the density curve.
    dens_list = ax.get_lines()[0].get_data()[1]

    # find the maximum y-coordinates of the density curve.
    max_dens_index = dens_list.argmax()

    # find the mode of the density plot.
    mode_x = ax.get_lines()[0].get_data()[0][max_dens_index]

    # draw a vertical line at the mode of the histogram.
    ax.axvline(mode_x, color='blue', linestyle='dashed', linewidth=1.5)
    ax.text(mode_x * 1.05, 0.16, 'Mode: {:.4f}'.format(mode_x))

    # Plot formatting
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4], [0.9, 0.8, 0.7, 0.6, 0.5],
             [0.2, 0.6, 0.75, 0.87, 0.91], [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]
num_subplots = len(stat_list)
ncols = 3
nrows = (num_subplots + ncols - 1) // ncols
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols * 6, nrows * 5))
colors = plt.cm.tab10.colors
for ax, stat, color in zip(np.ravel(axes), stat_list, colors):
    make_density(stat, color, 'x_label', 'y_label', ax)
for ax in np.ravel(axes)[num_subplots:]:  # remove possible empty subplots at the end
    ax.remove()
plt.show()
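The nrows line above is a ceiling division, so any number of subplots fits into ncols columns, and the cells left over in the last row are exactly the ones the second loop removes. A stdlib-only sketch of that arithmetic (grid_layout is a hypothetical helper) also shows where each flattened index ends up in the grid:

```python
def grid_layout(num_subplots, ncols=3):
    """Hypothetical helper: return (nrows, positions, n_empty) for a row-major subplot grid."""
    # ceiling division, same formula as in the answer: rounds up so every subplot gets a cell
    nrows = (num_subplots + ncols - 1) // ncols
    # np.ravel(axes) flattens row by row, so flat index i sits at (row, col) = divmod(i, ncols)
    positions = [divmod(i, ncols) for i in range(num_subplots)]
    # cells without data in the last row; these are the axes removed with ax.remove()
    n_empty = nrows * ncols - num_subplots
    return nrows, positions, n_empty


print(grid_layout(6))  # (2, [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)], 0)
nrows, positions, n_empty = grid_layout(7)
print(nrows, n_empty, positions[-1])  # 3 2 (2, 0)
```

np.ravel(axes) also copes with the 1-D array that plt.subplots returns when nrows or ncols is 1, so the same loop works for any number of lists in stat_list.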


PS: Instead of distplot, you could also use histplot (new in seaborn 0.11, the release in which distplot was deprecated). This should give a nicer plot, especially when the data are few and/or discrete.

sns.histplot(stat, kde=True, line_kws={'linewidth': 2}, color=color, ax=ax)


Upvotes: 2
