user9102437

Reputation: 742

Why does seaborn's pairplot not draw the first plot?

I am trying to make a pairplot using seaborn, but for some reason the first plot is left empty. What could cause this?

Here is the fully working code:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df = pd.read_csv("http://web.stanford.edu/~oleg2/hse/auto/Auto.csv").dropna()
med = df.mpg.median()
df['mpg01'] = [1 if i > med else 0 for i in df.mpg]
sub = df.drop(columns=['name'])
sns.pairplot(data=sub, x_vars=sub.columns, y_vars=['mpg01'])
plt.show()

Here is the output: [image: the resulting pairplot, with the first panel left empty]

Upvotes: 3

Views: 3046

Answers (2)

Vedant Terkar

Reputation: 4682

I know you've already accepted a solution, but this is worth mentioning in case it helps someone in the future.

I saw a similar issue reported on the seaborn repository on GitHub, where it was identified as a bug.

According to a comment there, passing diag_kind=None to sns.pairplot should resolve your issue.

In that case your code will look like this:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df = pd.read_csv("http://web.stanford.edu/~oleg2/hse/auto/Auto.csv").dropna()
med = df.mpg.median()
df['mpg01'] = [1 if i > med else 0 for i in df.mpg]
sub = df.drop(columns=['name'])
sns.pairplot(data=sub, x_vars=sub.columns,
             y_vars=['mpg01'], diag_kind=None)  # Note the change
plt.show()

Hope it'll solve your issue and help others.

Cheers!

Upvotes: 10

Andre

Reputation: 788

I did some tinkering, and the issue seems to come from how pairplot handles the diagonal plots (top left to bottom right); in your case that is the first plot (in the top left). Along the diagonal, the same variable normally "meets" itself on the x and y axes, and a histogram-like plot is drawn there (see the examples at https://seaborn.pydata.org/generated/seaborn.pairplot.html ).
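To illustrate the point about diagonal handling, here is a minimal sketch that builds the grid with PairGrid directly and calls map, which applies the same plotting function to every cell, so no cell gets separate diagonal treatment. It assumes seaborn 0.11 or later. The toy DataFrame and its values are made up to stand in for the Auto data, and the names toy, grid and drawn are only illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so the sketch runs without a display

import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for the Auto data (hypothetical values, not the real file).
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "mpg": rng.normal(23, 7, 100),
    "weight": rng.normal(3000, 800, 100),
    "horsepower": rng.normal(100, 35, 100),
})
toy["mpg01"] = (toy["mpg"] > toy["mpg"].median()).astype(int)

# One row (mpg01) against every column of the frame.
grid = sns.PairGrid(toy, x_vars=list(toy.columns), y_vars=["mpg01"])

# map() draws with the same function in every subplot, including mpg01 vs. mpg01.
grid.map(sns.scatterplot)

# Number of scatter collections drawn on each axes; an empty panel would show 0.
drawn = [len(ax.collections) for ax in grid.axes.flat]
print(drawn)
```

In an interactive session you would drop the matplotlib.use("Agg") line and call plt.show() instead of inspecting drawn.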

You could use subplots, which is the most straightforward way to get the plot you want.

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df = pd.read_csv("http://web.stanford.edu/~oleg2/hse/auto/Auto.csv").dropna()
med = df.mpg.median()
df['mpg01'] = [1 if i > med else 0 for i in df.mpg]
sub = df.drop(columns=['name'])

fig, axs = plt.subplots(1,len(sub.columns), figsize=(3*len(sub.columns),2.5), sharey=True)

for i, col_name in enumerate(sub.columns):
    sns.scatterplot(data=sub, x=col_name, y="mpg01", ax=axs[i])
plt.show()

You'll get this: [image: the resulting row of scatter plots]

Alternatively, use pairplot just as you did, but reverse the order of the columns by adding [::-1]:

sns.pairplot(data=sub, x_vars=sub.columns[::-1], y_vars='mpg01')

Then you'll get: [image: the resulting pairplot]

Upvotes: 2
