user1960836

Reputation: 1782

How to create plots inside a subplot figure?

I'm trying to draw 5 horizontally aligned plots, each with a different k value, so they can be compared.

I managed to draw one figure, but when I loop 5 times, only one plot appears:

from sklearn.neighbors import KNeighborsClassifier
import mglearn
import matplotlib.pyplot as plt

# X_test (two features per sample) and c_test (class labels) are defined earlier
clf = KNeighborsClassifier(n_neighbors=3)
clf.fit(X_test, c_test)

for counter in range(5):
    mglearn.discrete_scatter(X_test[:,0], X_test[:,1], c_test)
    plt.legend(["Class 0", "Class 1"], loc=4)
    plt.xlabel("First feature")
    plt.ylabel("Second feature")

How can I display 5 horizontally aligned plots?

Upvotes: 1

Views: 247

Answers (2)

Trenton McKinney

Reputation: 62583

  • Use plt.subplots and specify the number of columns with the ncols parameter.
  • When plotting, use counter to index the correct Axes, with ax=ax[counter].
  • Add plt.tight_layout() to adjust the spacing between the plots; otherwise the y-labels may overlap the adjacent plot.
fig, ax = plt.subplots(ncols=5, figsize=(20, 6))  # one row with 5 columns of Axes
for counter in range(5):
    mglearn.discrete_scatter(X_test[:,0], X_test[:,1], c_test, ax=ax[counter])
    # use the Axes methods: plt.legend/plt.xlabel/plt.ylabel only affect the current (last) Axes
    ax[counter].legend(["Class 0", "Class 1"], loc=4)
    ax[counter].set_xlabel("First feature")
    ax[counter].set_ylabel("Second feature")
plt.tight_layout()  # adjusts spacing so labels do not overlap neighboring plots
plt.show()
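The loop above draws the same scatter five times; the question also asks to compare different k values. A minimal sketch of that, assuming scikit-learn's `make_blobs` as stand-in data (the question's `X_test` and `c_test` are not shown) and a hypothetical list of k values: fit one `KNeighborsClassifier` per subplot, shade its predicted class regions with `contourf`, and label each subplot through its own Axes methods. It avoids mglearn so it only needs NumPy, Matplotlib and scikit-learn.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.neighbors import KNeighborsClassifier

# Stand-in data (assumption): two 2-D blobs instead of the question's X_test / c_test
X, y = make_blobs(n_samples=60, centers=2, random_state=4)
k_values = [1, 3, 5, 9, 15]  # hypothetical k values, one per subplot

# Dense grid over the feature space, used to shade each predicted class region
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200), np.linspace(y_min, y_max, 200))
grid = np.c_[xx.ravel(), yy.ravel()]

fig, axes = plt.subplots(ncols=len(k_values), figsize=(20, 4), sharey=True)
for axis, k in zip(axes, k_values):
    clf = KNeighborsClassifier(n_neighbors=k).fit(X, y)
    # Predict a class for every grid point and reshape back to the grid's shape
    zz = clf.predict(grid).reshape(xx.shape)
    axis.contourf(xx, yy, zz, levels=[-0.5, 0.5, 1.5], alpha=0.3, cmap="coolwarm")
    for label in (0, 1):
        mask = y == label
        axis.scatter(X[mask, 0], X[mask, 1], label=f"Class {label}", edgecolor="k")
    # Axes methods (not plt.*) so every subplot gets its own title, label and legend
    axis.set_title(f"k = {k}")
    axis.set_xlabel("First feature")
    axis.legend(loc=4)
axes[0].set_ylabel("Second feature")  # y-axis is shared, so label it once
fig.tight_layout()
plt.show()
```

Iterating with `zip(axes, k_values)` ties each subplot to its k directly, so there is no separate counter to keep in sync with the list.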

Example

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# sinusoidal sample data
sample_length = range(1, 4+1)
rads = np.arange(0, 2*np.pi, 0.01)
data = np.array([np.sin(t*rads) for t in sample_length])
df = pd.DataFrame(data.T, index=pd.Series(rads.tolist(), name='radians'), columns=[f'freq: {i}x' for i in sample_length])

# plot with subplots: one column per frequency
fig, ax = plt.subplots(ncols=4, figsize=(20, 5))
for i, col in enumerate(df.columns):
    d = pd.DataFrame(df[col])
    sns.lineplot(x=d.index, y=col, data=d, ax=ax[i])

plt.tight_layout()
plt.show()

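For data that is already in a DataFrame, pandas can also create the row of subplots itself. A sketch using `DataFrame.plot` with `subplots=True` (one Axes per column) and `layout=(1, 4)` (a single row), rebuilt on the same sinusoidal sample data so it runs on its own; this draws with pandas' Matplotlib plotting rather than seaborn.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# same sinusoidal sample data as in the example above
sample_length = range(1, 4 + 1)
rads = np.arange(0, 2 * np.pi, 0.01)
data = np.array([np.sin(t * rads) for t in sample_length])
df = pd.DataFrame(data.T, index=pd.Index(rads, name="radians"),
                  columns=[f"freq: {i}x" for i in sample_length])

# subplots=True gives one Axes per column; layout=(1, 4) puts them in a single row
axes = df.plot(subplots=True, layout=(1, 4), figsize=(20, 5))
plt.tight_layout()
plt.show()
```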

Upvotes: 1

tenhjo

Reputation: 4537

You could plot everything in a single figure on different axes:

import matplotlib.pyplot as plt
import mglearn
fig, ax = plt.subplots(nrows=1, ncols=5)  # ax is a 1-D array of 5 Axes
for counter in range(5):
    mglearn.discrete_scatter(X_test[:,0], X_test[:,1], c_test, ax=ax[counter])
plt.show()

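Passing `ax=ax[counter]` places each scatter on its own subplot, but labels and legends need the same per-Axes treatment. A small check of why: after `plt.subplots`, pyplot's "current Axes" is the last one created, so a pyplot-level call such as `plt.xlabel` labels only that subplot, while the Axes method `set_xlabel` targets exactly the subplot it is called on. The label strings here are illustrative only.

```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=1, ncols=5)
print(type(axes).__name__, axes.shape)  # a NumPy array holding the 5 Axes

plt.xlabel("set via pyplot")  # goes to plt.gca(), i.e. only the current Axes
pyplot_labels = [a.get_xlabel() for a in axes]

for i, a in enumerate(axes):
    a.set_xlabel(f"set on axes[{i}]")  # Axes method: labels exactly this subplot
axes_labels = [a.get_xlabel() for a in axes]

print(pyplot_labels)
print(axes_labels)
```

One related caveat: with `nrows=1, ncols=1`, `plt.subplots` returns a single Axes rather than an array, so `ax[counter]` would fail; passing `squeeze=False` always returns a 2-D array of Axes.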
Upvotes: 1
