dg S

Reputation: 85

Seaborn catplot is throwing an error: truth value is ambiguous

I am trying to make a catplot with the seaborn library for every categorical variable in my dataframe, but I am getting an "ambiguous truth value" error. This error usually comes from combining boolean conditions with "&", but I can't find the root cause here. My target is a continuous variable.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# df and target_col are defined earlier (not shown)
target = df[target_col]
features = df[df.columns.difference([target_col])]

cat_cols = features.select_dtypes(include=['object']).columns.to_list()

fig, axes = plt.subplots(round(len(cat_cols) / 3), 3, figsize=(15, 15))
for i, ax in enumerate(fig.axes):
    if i < len(cat_cols):
        sns.catplot(x=cat_cols[i], y=target, kind='bar', data=df, ax=ax)

But I get the error below. Which part is causing this ValueError?

ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().

Upvotes: 1

Views: 1434

Answers (1)

StupidWolf

Reputation: 46908

sns.catplot is a figure-level function (it builds its own FacetGrid), so you should not put it into a subplot. Instead, you can use a FacetGrid with barplot:

For example, suppose this is your data:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.DataFrame({'y': np.random.uniform(0, 1, 50),
                   'A': np.random.choice(['a1', 'a2'], 50),
                   'B': np.random.choice(['b1', 'b2'], 50),
                   'C': np.random.randint(0, 10, 50),
                   'D': np.random.choice(['d1', 'd2'], 50),
                   'E': np.random.choice(['e1', 'e2'], 50)})
target_col = "y"
cat_cols = df.columns[df.dtypes==object]

Seaborn works best with data in long format, so you can melt your data into long form like this:

df.melt(id_vars=target_col,value_vars=cat_cols)

          y variable value
0  0.606734        A    a1
1  0.603324        A    a2
2  0.938280        A    a2
3  0.718703        A    a1
4  0.808013        A    a1
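To see what the long format looks like beyond the first rows, here is a small sketch that checks the shape of the melted frame. It adds a seeded generator for reproducibility and lists the categorical columns explicitly instead of detecting them by dtype; both are assumptions not in the answer. Each of the four categorical columns contributes 50 rows, and the target y is repeated alongside each one.

```python
import numpy as np
import pandas as pd

# Frame shaped like the answer's example (random values; seed added for reproducibility)
rng = np.random.default_rng(0)
n = 50
df = pd.DataFrame({'y': rng.uniform(0, 1, n),
                   'A': rng.choice(['a1', 'a2'], n),
                   'B': rng.choice(['b1', 'b2'], n),
                   'D': rng.choice(['d1', 'd2'], n),
                   'E': rng.choice(['e1', 'e2'], n)})
target_col = "y"
cat_cols = ['A', 'B', 'D', 'E']  # listed explicitly instead of detected by dtype

# Wide -> long: one row per (original row, categorical column) pair
long_df = df.melt(id_vars=target_col, value_vars=cat_cols)

print(long_df.shape)                                 # (200, 3)
print(list(long_df.columns))                         # ['y', 'variable', 'value']
print(long_df.groupby('variable').size().to_dict())  # {'A': 50, 'B': 50, 'D': 50, 'E': 50}

# The target is repeated once per categorical column, in the original row order
print(np.allclose(long_df.loc[long_df['variable'] == 'A', target_col], df[target_col]))  # True
```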

The "variable" column now defines which facet each row is plotted in, and the "value" column goes on the x-axis. We can pass the melted frame directly to FacetGrid:

g = sns.FacetGrid(df.melt(id_vars=target_col, value_vars=cat_cols),
                  col='variable', sharex=False, col_wrap=3)
g.map_dataframe(sns.barplot, x="value", y="y")
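An alternative that is not part of the original answer: if you want to keep the question's plt.subplots grid, call the axes-level sns.barplot, which accepts ax=, instead of sns.catplot. This sketch passes column names for both x and y together with data=, rather than the target Series used in the question. It assumes example data like the answer's (categorical columns listed explicitly), and it swaps round for math.ceil, since round(4 / 3) is 1, which would create only three axes for four categorical columns.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Example data in the same shape as the answer's (random values, seeded)
rng = np.random.default_rng(0)
n = 50
df = pd.DataFrame({'y': rng.uniform(0, 1, n),
                   'A': rng.choice(['a1', 'a2'], n),
                   'B': rng.choice(['b1', 'b2'], n),
                   'D': rng.choice(['d1', 'd2'], n),
                   'E': rng.choice(['e1', 'e2'], n)})
target_col = 'y'
cat_cols = ['A', 'B', 'D', 'E']

ncols = 3
# ceil (not round) so every categorical column gets its own axes
nrows = math.ceil(len(cat_cols) / ncols)
fig, axes = plt.subplots(nrows, ncols, figsize=(15, 5 * nrows), squeeze=False)

for ax, col in zip(axes.ravel(), cat_cols):
    # Axes-level function: draws onto the given ax; x and y are column names in data
    sns.barplot(data=df, x=col, y=target_col, ax=ax)

# Hide any axes left over when len(cat_cols) is not a multiple of ncols
for ax in axes.ravel()[len(cat_cols):]:
    ax.set_visible(False)

fig.tight_layout()
```

For a figure-level version of the same grid, sns.catplot(data=df.melt(id_vars=target_col, value_vars=cat_cols), x='value', y=target_col, col='variable', col_wrap=3, kind='bar', sharex=False) should work without a pre-built subplot grid, since catplot creates its own FacetGrid.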


Upvotes: 2
