I am trying to make a pairplot using sns, but for some reason it refuses to plot the first subplot. What may cause this issue?
Here is the complete, runnable code:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
df = pd.read_csv("http://web.stanford.edu/~oleg2/hse/auto/Auto.csv").dropna()
med = df.mpg.median()
df['mpg01'] = [1 if i > med else 0 for i in df.mpg]
sub = df.drop(columns=['name'])
sns.pairplot(data=sub, x_vars=sub.columns, y_vars=['mpg01'])
plt.show()
I know you've already accepted a solution, but this is worth mentioning in case it helps someone in the future.
I saw a similar issue on the seaborn repository on GitHub, where it has been identified as a bug. According to the comment there, adding the parameter diag_kind=None when calling the sns.pairplot function should resolve your issue.
In that case, your code will look like this:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
df = pd.read_csv("http://web.stanford.edu/~oleg2/hse/auto/Auto.csv").dropna()
med = df.mpg.median()
df['mpg01'] = [1 if i > med else 0 for i in df.mpg]
sub = df.drop(columns=['name'])
sns.pairplot(data=sub, x_vars=sub.columns,
             y_vars=['mpg01'], diag_kind=None)  # Note the change
plt.show()
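To check the workaround without downloading the CSV, here is a minimal offline sketch using synthetic stand-in data (hypothetical random values, not the real Auto data). As far as I can tell from the seaborn source, when diag_kind=None pairplot skips the diagonal histograms and draws the scatter plot on every panel, so each axis should end up holding at least one collection of points.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering for a quick check
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in data (hypothetical values, not the Auto dataset)
rng = np.random.default_rng(1)
sub = pd.DataFrame({
    "mpg": rng.uniform(10, 45, 100),
    "weight": rng.uniform(1600, 5000, 100),
})
sub["mpg01"] = (sub["mpg"] > sub["mpg"].median()).astype(int)

# diag_kind=None: no diagonal histograms, scatter on every panel
g = sns.pairplot(data=sub, x_vars=sub.columns, y_vars=["mpg01"], diag_kind=None)

# Count the panels that received scatter points (stored as collections)
panels_with_points = sum(len(ax.collections) > 0 for ax in g.axes.flat)
print(panels_with_points, "of", g.axes.size, "panels have scatter points")
```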
Hope it'll solve your issue and help others.
Cheers!
I did some tinkering, and it seems the issue has to do with how pairplot handles the diagonal plots (top-left to bottom-right); in your case, that is the first plot (top-left). Normally, the same variable "meets" itself on the x and y axes along the diagonal, so a histogram-like plot is drawn there instead of a scatter plot (see the examples at https://seaborn.pydata.org/generated/seaborn.pairplot.html ).
You could use subplots instead; that is the most straightforward way to get the plot you want:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
df = pd.read_csv("http://web.stanford.edu/~oleg2/hse/auto/Auto.csv").dropna()
med = df.mpg.median()
df['mpg01'] = [1 if i > med else 0 for i in df.mpg]
sub = df.drop(columns=['name'])
# One subplot per column, all sharing the mpg01 y-axis
fig, axs = plt.subplots(1, len(sub.columns), figsize=(3 * len(sub.columns), 2.5), sharey=True)
for i, col_name in enumerate(sub.columns):
    sns.scatterplot(data=sub, x=col_name, y="mpg01", ax=axs[i])
plt.show()
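Alternatively, use pairplot just as you did, but reverse the order of the columns with [::-1]. This puts mpg01 first, so the top-left panel is mpg01 plotted against itself, and nothing meaningful is lost on the diagonal: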
sns.pairplot(data=sub, x_vars=sub.columns[::-1], y_vars='mpg01')
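Why reversing helps can be illustrated with a small pure-Python sketch. It assumes, following the explanation above, that the affected seaborn version picks the diagonal panel by position (row i, column i) rather than by checking whether the same variable sits on both axes. Both helper functions are hypothetical and are not part of seaborn, and the column list is a shortened, hypothetical subset of the Auto columns.

```python
def positional_diagonal(x_vars, y_vars):
    """Hypothetical helper: (row, col) cells on the top-left to bottom-right diagonal."""
    n = min(len(x_vars), len(y_vars))
    return [(i, i) for i in range(n)]


def matching_diagonal(x_vars, y_vars):
    """Hypothetical helper: (row, col) cells where the same variable is on both axes."""
    return [(i, j)
            for i, y in enumerate(y_vars)
            for j, x in enumerate(x_vars)
            if x == y]


cols = ["mpg", "cylinders", "displacement", "mpg01"]  # shortened column list
y_vars = ["mpg01"]

print(positional_diagonal(cols, y_vars))        # [(0, 0)]: the mpg vs mpg01 panel
print(matching_diagonal(cols, y_vars))          # [(0, 3)]: the mpg01 vs mpg01 panel
print(positional_diagonal(cols[::-1], y_vars))  # [(0, 0)]: now mpg01 vs mpg01
print(matching_diagonal(cols[::-1], y_vars))    # [(0, 0)]: both notions agree
```

With mpg01 last, the two notions disagree and the positional diagonal swallows a real scatter panel; after reversing, they coincide on the panel that plots mpg01 against itself.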