PyRsquared

Reputation: 7338

How to plot multiple figures in a row using seaborn

I have a dataframe df that looks like this:

df.head()
id        feedback        nlp_model        similarity_score
0xijh4    1               tfidf            0.36
0sdnj7    -1              lda              0.89
kjh458    1               doc2vec          0.78
....
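To run the snippets below without the original data, here is a minimal stand-in DataFrame. It contains only the three rows visible in df.head() above. The real df has more rows that are not shown, so treat this as a placeholder rather than the actual data.

```python
import pandas as pd

# Only the three rows visible in df.head(); the real df has more rows ("....").
df = pd.DataFrame({
    "id": ["0xijh4", "0sdnj7", "kjh458"],
    "feedback": [1, -1, 1],
    "nlp_model": ["tfidf", "lda", "doc2vec"],
    "similarity_score": [0.36, 0.89, 0.78],
})

# Boolean indexing as used in the question: keep only the rows for one model.
print(df[df.nlp_model == "tfidf"])
```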

I want to plot similarity_score versus feedback as a boxplot using seaborn, with one plot for each unique value in the nlp_model column: tfidf, lda, doc2vec. My code for this is as follows:

import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(figsize=(10,8))
ax = sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='tfidf'])
ax = sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='tfidf'], color="0.25")

fig, ax = plt.subplots(figsize=(10,8))
ax = sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='lda'])
ax = sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='lda'], color="0.25")

fig, ax = plt.subplots(figsize=(10,8))
ax = sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='doc2vec'])
ax = sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='doc2vec'], color="0.25")

plt.show()

The problem is that this creates three separate figures, one on top of the other.


How can I generate these same plots all on a single row, with the "Similarity Score" y-axis label on the leftmost plot only and the "Feedback" x-axis label directly below each plot?

Upvotes: 12


Answers (1)

DavidG

Reputation: 25363

You are creating a new figure each time you call plt.subplots(), so remove all but one of those calls.
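A quick way to see this is to count the open figures with plt.get_fignums(). This sketch assumes a headless "Agg" backend so it runs without a display. Three plt.subplots() calls, as in the question, leave three figures. A single plt.subplots(1, 3) call leaves one figure holding three axes.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

plt.close("all")
# Pattern from the question: one plt.subplots() call per model.
for _ in range(3):
    fig, ax = plt.subplots(figsize=(10, 8))
print(len(plt.get_fignums()))  # 3 separate figures

plt.close("all")
# One call that creates a single figure with three axes in one row.
fig, axes = plt.subplots(1, 3)
print(len(plt.get_fignums()), len(axes))  # 1 figure, 3 axes
```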

Seaborn's swarmplot() and boxplot() accept an ax argument, i.e. you can tell them which axes to plot on. Therefore, create your figure and its three subplot axes in a single row using:

fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
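The first two arguments are the number of rows and columns, so (1, 3) lays the axes out side by side. To confirm this, inspect each axes' position in figure coordinates: all three share the same bottom edge (y0), and their left edges (x0) increase from left to right. The figsize=(15, 5) here is an assumption chosen for a wide row. It is not taken from the answer.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# get_position() returns a Bbox in figure coordinates (0..1).
positions = [ax.get_position() for ax in (ax1, ax2, ax3)]
for p in positions:
    print(round(p.x0, 3), round(p.y0, 3))
```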

Then you can do something like:

sns.boxplot(x="x_vals", y="y_vals", data=some_data, ax=ax1)
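With ax=ax1, seaborn draws onto that existing Axes and returns it, and the other axes stay empty. In this sketch, some_data and its x_vals/y_vals columns are hypothetical placeholders matching the names in the line above.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical data, only so there is something to draw.
some_data = pd.DataFrame({"x_vals": ["a", "a", "b", "b"],
                          "y_vals": [1.0, 2.0, 3.0, 4.0]})

fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
returned = sns.boxplot(x="x_vals", y="y_vals", data=some_data, ax=ax1)

print(returned is ax1)  # True: seaborn drew on the Axes it was given
print(ax2.has_data())   # False: ax2 was not touched
```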

You can then adjust each axes as you see fit, for example by removing the y-axis label on all but the leftmost subplot:

import matplotlib.pyplot as plt
import seaborn as sns

fig, (ax1, ax2, ax3) = plt.subplots(1,3,figsize=(10,8))

sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='tfidf'], ax=ax1)
sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='tfidf'], color="0.25", ax=ax1)

sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='lda'], ax=ax2)
sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='lda'], color="0.25", ax=ax2)

ax2.set_ylabel("")  # remove y label, but keep ticks

sns.boxplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='doc2vec'], ax=ax3)
sns.swarmplot(x="feedback", y="similarity_score", data=df[df.nlp_model=='doc2vec'], color="0.25", ax=ax3)

ax3.set_ylabel("")  # remove y label, but keep ticks

plt.show()
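The same layout can also be written as a loop over the models. This is only a sketch of one variant, not the answer's code. It makes three assumptions. The data is randomly generated as a hypothetical stand-in for df. sharey=True gives the three plots a common y-range, so dropping the y labels on the right two stays unambiguous. Finally, each subplot gets its model name as a title.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for df: random scores, 20 rows per model.
rng = np.random.default_rng(0)
models = ["tfidf", "lda", "doc2vec"]
df = pd.DataFrame({
    "feedback": rng.choice([-1, 1], size=60),
    "nlp_model": np.repeat(models, 20),
    "similarity_score": rng.uniform(0, 1, size=60),
})

# One figure, one row of axes; sharey keeps the score scale identical.
fig, axes = plt.subplots(1, len(models), figsize=(15, 5), sharey=True)

for ax, model in zip(axes, models):
    subset = df[df.nlp_model == model]
    sns.boxplot(x="feedback", y="similarity_score", data=subset, ax=ax)
    sns.swarmplot(x="feedback", y="similarity_score", data=subset,
                  color="0.25", ax=ax)
    ax.set_title(model)
    ax.set_xlabel("Feedback")  # x label below every plot
    ax.set_ylabel("")          # cleared here, restored on the first axes

# Only the leftmost subplot keeps a y-axis label.
axes[0].set_ylabel("Similarity Score")
fig.tight_layout()
```

Looping over the model names avoids repeating the boxplot/swarmplot pair three times. Adding a fourth model only changes the models list.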

Upvotes: 21
