vestland

Reputation: 61104

How to retrieve all data from seaborn distribution plot with multiple distributions?

The post "Get data points from Seaborn distplot" describes how you can get the data elements using sns.distplot(x).get_lines()[0].get_data(), sns.distplot(x).patches and [h.get_height() for h in sns.distplot(x).patches].

But how can you do this if you've used multiple layers by plotting the data in a loop, such as:

Snippet 1

for var in list(df):
    print(var)
    distplot = sns.distplot(df[var])

Plot


Is there a way to retrieve the X and Y values for both the line charts and the bars?


Here's the whole setup for easy copy & paste:

#%%
# imports
from collections import OrderedDict
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

warnings.simplefilter(action='ignore', category=FutureWarning)
plt.rcParams['figure.figsize'] = (8, 4)

# Function to build synthetic data
def sample(rSeed, periodLength, colNames):
    np.random.seed(rSeed)
    date = pd.to_datetime("1st of Dec, 1999")
    cols = OrderedDict()

    for col in colNames:
        cols[col] = np.random.normal(loc=0.0, scale=1.0, size=periodLength)
    dates = date + pd.to_timedelta(np.arange(periodLength), 'D')

    df = pd.DataFrame(cols, index=dates)
    return df

# Dataframe with synthetic data
df = sample(rSeed = 123, colNames = ['X1', 'X2'], periodLength = 50)

# sns.distplot with multiple layers
for var in list(df):
    myPlot = sns.distplot(df[var])

Here's what I've tried:

Y-values for histogram:

If I run:

barY = [h.get_height() for h in myPlot.patches]

Then I get the following list of length 11:

[0.046234272703757885,
 0.1387028181112736,
 0.346757045278184,
 0.25428849987066837,
 0.2542884998706682,
 0.11558568175939472,
 0.11875881712519201,
 0.3087729245254993,
 0.3087729245254993,
 0.28502116110046083,
 0.1662623439752689]

This seems reasonable, since there appear to be 6 values for the blue bars and 5 values for the red bars. But how do I tell which values belong to which variable?

Y-values for line:

This seems a bit easier than the histogram part, since you can use myPlot.get_lines()[0].get_data() and myPlot.get_lines()[1].get_data() to get:

Out[678]: 
(array([-4.54448949, -4.47612134, -4.40775319, -4.33938504, -4.27101689,
         ...
         3.65968859,  3.72805675,  3.7964249 ,  3.86479305,  3.9331612 ,
         4.00152935,  4.0698975 ,  4.13826565]),
 array([0.00042479, 0.00042363, 0.000473  , 0.00057404, 0.00073097,
        0.00095075, 0.00124272, 0.00161819, 0.00208994, 0.00267162,
        ...
        0.0033384 , 0.00252219, 0.00188591, 0.00139919, 0.00103544,
        0.00077219, 0.00059125, 0.00047871]))

myPlot.get_lines()[1].get_data()

Out[679]: 
(array([-3.68337423, -3.6256517 , -3.56792917, -3.51020664, -3.4524841 ,
        -3.39476157, -3.33703904, -3.27931651, -3.22159398, -3.16387145,
         ...
         3.24332952,  3.30105205,  3.35877458,  3.41649711,  3.47421965,
         3.53194218,  3.58966471,  3.64738724]),
 array([0.00035842, 0.00038018, 0.00044152, 0.00054508, 0.00069579,
        0.00090076, 0.00116922, 0.00151242, 0.0019436 , 0.00247792,
        ...
        0.00215912, 0.00163627, 0.00123281, 0.00092711, 0.00070127,
        0.00054097, 0.00043517, 0.00037599]))

But the whole thing still seems cumbersome. Does anyone know a more direct approach that retrieves all the data into a dictionary or a DataFrame?

Upvotes: 4

Views: 7284

Answers (1)

Miguel Trejo

Reputation: 6667

I had the same need to retrieve data from a seaborn distribution plot. What worked for me was calling the .findobj() method on the Axes returned in each iteration. Among the returned artists are matplotlib.lines.Line2D objects, which have a get_data() method; this is the same method you used in myPlot.get_lines()[1].get_data().

Following your example code:

import matplotlib.lines

data = []
for idx, var in enumerate(list(df)):
    myPlot = sns.distplot(df[var])

    # Find Line2D objects
    lines2D = myPlot.findobj(matplotlib.lines.Line2D)

    # Retrieve x, y data
    x, y = lines2D[idx].get_data()

    # Store as dataframe
    data.append(pd.DataFrame({'x': x, 'y': y}))

Notice that the data for the first sns.distplot is stored at index 0 of lines2D and the data for the second at index 1. This happens because both plots are drawn on the same Axes, and findobj() returns that Axes' Line2D objects in the order they were added (other Line2D artists, such as the axis tick marks, come after them). So if you plot more than two distributions, you can access each sns.distplot's data at its respective index of lines2D.

Finally, to verify, you can plot each stored curve, for example the first one:

plt.plot(data[0].x, data[0].y)


Upvotes: 3
