arilwan

Reputation: 3993

matplotlib: plotting more than one figure at once

I am working with 3 pandas dataframes that have the same column structure (number and type of columns); the only difference is that each dataset covers a different year.

I would like to plot the ECDF for each of the dataframes, but every time I do this, I do it individually (I lack Python skills). Also, one of the figures (2018) has a different x-axis scale, which makes the years a bit difficult to compare. Here's how I do it:

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from empiricaldist import Cdf

    df1 = pd.read_csv('2016.csv')
    df2 = pd.read_csv('2017.csv')
    df3 = pd.read_csv('2018.csv')

    # some info about the dfs
    df1.columns.values
    # array(['id', 'trip_id', 'distance', 'duration', 'speed', 'foot', 'bike',
    #        'car', 'bus', 'metro', 'mode'], dtype=object)

    modal_class = df1['mode']
    print(modal_class[:5])
    # 0         bus
    # 1         metro
    # 2         bike
    # 3         foot
    # 4         car

    def decorate_ecdf(title, x, y):
        plt.xlabel(x)
        plt.ylabel(y)
        plt.title(title)

    #plotting the ecdf for 2016 dataset
    for name, group in df1.groupby('mode'):
        Cdf.from_seq(group.speed).plot()    
    title, x, y = 'Speed distribution by travel mode (April 2016)','speed (m/s)', 'ECDF'
    decorate_ecdf(title,x,y)

    #plotting the ecdf for 2017 dataset
    for name, group in df2.groupby('mode'):
        Cdf.from_seq(group.speed).plot()    
    title, x, y = 'Speed distribution by travel mode (April 2017)','speed (m/s)', 'ECDF'
    decorate_ecdf(title,x,y)

    #plotting the ecdf for 2018 dataset
    for name, group in df3.groupby('mode'):
        Cdf.from_seq(group.speed).plot()    
    title, x, y = 'Speed distribution by travel mode (April 2018)','speed (m/s)', 'ECDF'
    decorate_ecdf(title,x,y)

Output:

[Figures: ECDF of speed by travel mode for April 2016, April 2017 and April 2018 (images not available)]

I am pretty sure this isn't the Pythonic way of doing it, just a quick-and-dirty way to get the work done. You can also see that the 2018 plot is scaled differently on the x-axis.

  1. Is there a way to force all figures to use the same x-axis scale?
  2. How can I rewrite my code so that all the figures are produced by calling a function once?

Upvotes: 0

Views: 324

Answers (1)

K.Cl

Reputation: 1773

When using pyplot, you can plot either implicitly, by calling functions such as plt.plot() that act on the "current" figure and axes, or explicitly, by creating the figure and axes objects with fig, ax = plt.subplots() and calling methods on them. What happened here is, in my view, a side effect of using the implicit method.
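A minimal illustration of the two styles (the data are arbitrary): in the implicit style, successive pyplot calls keep drawing on the current Axes until a new figure is created, while in the explicit style each call names the Axes it draws on.

```python
import matplotlib.pyplot as plt

# Implicit style: pyplot functions act on the "current" figure and Axes.
plt.figure()
plt.plot([0, 1], [0, 1])
plt.plot([0, 2], [1, 0])       # no new figure, so this lands on the same Axes
plt.title('implicit')          # titles whichever Axes is current
implicit_ax = plt.gca()

# Explicit style: create the objects first, then call methods on them.
fig, (ax_left, ax_right) = plt.subplots(ncols=2)
ax_left.plot([0, 1], [0, 1])
ax_right.plot([0, 2], [1, 0])  # each call names its target Axes
ax_left.set_title('explicit (left)')
```

The explicit version is more verbose, but it removes any doubt about which Axes a given call modifies.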

For example, you can make two pd.DataFrame.plot() calls draw on the same axes by passing the Axes object returned by the first call to the second one via the ax keyword.

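    import pandas as pd
    import matplotlib.pyplot as plt

    foo = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
    bar = pd.DataFrame(dict(c=[3, 2, 1], d=[6, 5, 4]))
    ax = foo.plot()                 # returns the Axes it drew on
    bar.plot(ax=ax)                 # ax object is updated instead of a new one being created
    ax.plot([0, 3], [1, 1], 'k--')  # the Axes can also be used directly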

You can also create the figure and axes objects beforehand and pass them in as needed. It's also perfectly fine to have multiple plot commands: often my code is 25% work and 75% fiddling with plots. Don't try to be clever at the expense of readability.

    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True)
    # In this case, axes is a NumPy array with 3 Axes objects
    # You can access the objects by indexing
    # All of them share the same x range because of sharex=True
    axes[0].plot([-1, 2], [1, 1])
    axes[1].plot([-2, 1], [1, 1])
    axes[2].plot([1, 3], [1, 1])

You can combine both of these snippets in your own code: first create the figure and axes objects, then plot each dataframe, passing the appropriate Axes to it with the ax keyword.
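Putting this together for the question, here is one possible sketch of a single function that draws one ECDF panel per year. It assumes each dataframe has the `speed` and `mode` columns shown in the question. The helper names `ecdf` and `plot_speed_ecdfs` are hypothetical, and the ECDF is computed directly with NumPy instead of empiricaldist's `Cdf`, so the sketch does not depend on how `Cdf.plot()` picks its Axes.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def ecdf(values):
    """Return the sorted values and their empirical cumulative probabilities."""
    xs = np.sort(np.asarray(values, dtype=float))
    ps = np.arange(1, len(xs) + 1) / len(xs)
    return xs, ps


def plot_speed_ecdfs(frames, years, column='speed', by='mode'):
    """Draw one ECDF panel per dataframe; all panels share one x-axis."""
    fig, axes = plt.subplots(nrows=len(frames), ncols=1, sharex=True,
                             figsize=(6, 3 * len(frames)), squeeze=False)
    axes = axes.ravel()  # squeeze=False gives a 2-D array; flatten it
    for ax, df, year in zip(axes, frames, years):
        for name, group in df.groupby(by):
            xs, ps = ecdf(group[column])
            # The ECDF is a right-continuous step function.
            ax.step(xs, ps, where='post', label=name)
        ax.set_title(f'Speed distribution by travel mode (April {year})')
        ax.set_ylabel('ECDF')
        ax.legend()
    axes[-1].set_xlabel('speed (m/s)')  # shared axis: label the bottom panel only
    return fig, axes
```

With the question's data this would be a single call, `fig, axes = plot_speed_ecdfs([df1, df2, df3], [2016, 2017, 2018])`, followed by `plt.show()`. Because of `sharex=True`, all panels use the union of the x ranges, so the 2018 panel is no longer scaled differently from the others.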

Also, suppose you have three Axes objects with different x limits. You can read all of their limits, then set all three to the same minimum and the same maximum. For example:

    axis_list = [ax1, ax2, ax3]  # suppose you created these separately and want to enforce the same x limits
    minimum_x = min(ax.get_xlim()[0] for ax in axis_list)
    maximum_x = max(ax.get_xlim()[1] for ax in axis_list)
    for ax in axis_list:
        ax.set_xlim(minimum_x, maximum_x)
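As a sketch, the snippet above can be wrapped in a small reusable helper and tried on three separately created figures whose data cover different ranges. The function name `match_xlims` and the sample data are hypothetical. Note that `set_xlim` also switches off x autoscaling for those Axes by default, so call the helper after all data has been plotted.

```python
import matplotlib.pyplot as plt


def match_xlims(axis_list):
    """Set every Axes in axis_list to the union of their current x ranges."""
    minimum_x = min(ax.get_xlim()[0] for ax in axis_list)
    maximum_x = max(ax.get_xlim()[1] for ax in axis_list)
    for ax in axis_list:
        # set_xlim (auto=False by default) also turns x autoscaling off
        ax.set_xlim(minimum_x, maximum_x)
    return minimum_x, maximum_x


# Hypothetical data: three separate figures covering different x ranges.
axis_list = []
for x_range in ([0, 10], [0, 20], [-5, 50]):
    fig, ax = plt.subplots()
    ax.plot(x_range, [0, 1])
    axis_list.append(ax)

shared_range = match_xlims(axis_list)
```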

Upvotes: 1
