GKroch

Reputation: 83

Subplot from already created plots

I have a function that returns a plot for a specific column:

def class_distribution(colname):
    df = tweets_best.groupby(["HandLabel", colname]).size().to_frame("size")
    df['percentage'] = df.groupby(level=0).transform(lambda x: (x / x.sum()).round(2))
    df_toPlot = df[["percentage"]]
    
    plot = df_toPlot.unstack().plot.bar()
    plt.legend(df_toPlot.index.get_level_values(level = 1))
    plt.title("{} predicted sentiment distribution".format(colname))
    plt.ylim((0,1))
    plt.gca().yaxis.set_major_formatter(PercentFormatter(1))
    return plot.get_figure()

An example output looks like this:

nb = class_distribution("Naive_Bayes")

example_output

I would like to produce four plots like this and present them as subplots in 2 rows and 2 columns. However, if I try

plt.figure()
plt.subplot(1,2,1)
nb
plt.subplot(1,2,2)
sn

I get

example_output2

which is obviously not what I expected.

Upvotes: 3

Views: 3182

Answers (2)

David Collins

Reputation: 900

Actually, your output is exactly what you'd expect given your code:

plt.figure()
plt.subplot(1,2,1)
nb
plt.subplot(1,2,2)
sn

In the line plt.subplot(1,2,1), you're specifying a grid of one row and two columns and selecting the left-hand position.

The arguments (1,2,1) are (number of rows, number of columns, index of the subplot to draw on), with the index starting at 1 and running left to right, then top to bottom.

Since you want subplots arranged 2 by 2, specify (2,2,i), where i is the index. This will arrange your plots:

plt.figure()
plt.subplot(2,2,1)
# plot in upper left
plt.subplot(2,2,2)
# plot in upper right
plt.subplot(2,2,3)
# plot in lower left
plt.subplot(2,2,4)
# plot in lower right
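
To double-check which cell each index selects, here is a small sketch (not part of the original code) that reads the grid position back from each axes. It assumes a headless Agg backend so it runs without a display, and the `positions` dictionary is just an illustrative name.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt

fig = plt.figure()
positions = {}  # subplot index -> (row, column), both zero-based

for i in range(1, 5):
    ax = plt.subplot(2, 2, i)  # (rows, columns, 1-based index)
    ax.set_title("index {}".format(i))
    # get_geometry() returns (n_rows, n_cols, start, stop) with a 0-based start
    n_rows, n_cols, start, _ = ax.get_subplotspec().get_geometry()
    positions[i] = divmod(start, n_cols)

print(positions)
# {1: (0, 0), 2: (0, 1), 3: (1, 0), 4: (1, 1)}
```

So index 1 is the upper left, 2 the upper right, 3 the lower left, and 4 the lower right.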

Additionally, you can handle axes as ImportanceOfBeingErnest details. You can also share axes and make use of several other parameters and arguments: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.subplot.html

A minimal working example will better identify the problem and get better answers.

Upvotes: 2

ImportanceOfBeingErnest

Reputation: 339705

You need to plot to an already existing axes. So your function should take an axes as input:

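import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

def class_distribution(colname, ax=None):
    # fall back to the current axes when none is passed in
    if ax is None:
        ax = plt.gca()

    df = ...  # create dataframe based on function input

    # draw onto the given axes instead of creating a new figure
    df.unstack().plot.bar(ax=ax)
    ax.legend(...)
    ax.set_title("{} predicted sentiment distribution".format(colname))
    ax.set_ylim((0,1))
    ax.yaxis.set_major_formatter(PercentFormatter(1))
    return ax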

Then, you can create a figure and one or several subplots to plot to:

fig = plt.figure()

ax1 = fig.add_subplot(1,2,1)
class_distribution("colname1", ax=ax1)

ax2 = fig.add_subplot(1,2,2)
class_distribution("colname2", ax=ax2)
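
For the 2 by 2 layout asked about in the question, the same idea can be sketched end to end as follows. This is only an illustration under assumptions: `tweets_best` here is random synthetic data standing in for the asker's DataFrame, the column names other than `Naive_Bayes` are made up, the percentage computation is a simplified variant of the one in the question, and the Agg backend is chosen so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import PercentFormatter

# Synthetic stand-in for the asker's tweets_best (assumption, not real data)
rng = np.random.default_rng(0)
labels = ["Negative", "Neutral", "Positive"]
models = ["Naive_Bayes", "Model_B", "Model_C", "Model_D"]  # last three are hypothetical
tweets_best = pd.DataFrame({"HandLabel": rng.choice(labels, size=300)})
for model in models:
    tweets_best[model] = rng.choice(labels, size=300)

def class_distribution(colname, ax=None):
    if ax is None:
        ax = plt.gca()
    # rows: hand label, columns: predicted label, values: counts
    counts = tweets_best.groupby(["HandLabel", colname]).size().unstack(fill_value=0)
    # convert counts to the share of each predicted label within a hand label
    share = counts.div(counts.sum(axis=1), axis=0)
    share.plot.bar(ax=ax)  # draw on the axes that was passed in
    ax.legend(title="Predicted")
    ax.set_title("{} predicted sentiment distribution".format(colname))
    ax.set_ylim((0, 1))
    ax.yaxis.set_major_formatter(PercentFormatter(1))
    return ax

# One figure with a 2x2 grid of axes; axes.flat iterates row by row
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
for ax, model in zip(axes.flat, models):
    class_distribution(model, ax=ax)
fig.tight_layout()
```

In an interactive session you would finish with plt.show(), or call fig.savefig(...) to write the grid to a file.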

Upvotes: 2
