pbreach

Reputation: 17017

Changing color and marker of each point using seaborn jointplot

I have this code, slightly modified from here:

import seaborn as sns
sns.set(style="darkgrid")

tips = sns.load_dataset("tips")
color = sns.color_palette()[5]
g = sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg",
                  xlim=(0, 60), ylim=(0, 12), color='k', height=7)

g.set_axis_labels('total bill', 'tip', fontsize=16)

and I get a nice-looking plot. However, in my case I need to be able to change the color AND the marker of each individual point.

I've tried using the keywords marker, style, and fmt, but I get the error TypeError: jointplot() got an unexpected keyword argument.

What is the correct way to do this? I'd like to avoid calling sns.JointGrid and plotting the data and marginal distributions manually.

Upvotes: 19

Views: 45034

Answers (6)

Vincent Jeanselme

Reputation: 49

You can also specify this directly in the argument list through the joint_kws keyword (tested with seaborn 0.8.1). If needed, you can also change the properties of the marginal plots with marginal_kws. With kind="reg", joint_kws is forwarded to regplot, so per-point colors go in its scatter_kws, and marker accepts a single marker style.

So your code becomes:

import numpy as np
import seaborn as sns

tips = sns.load_dataset("tips")
colors = np.random.random((len(tips), 3))  # one RGB color per point

sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg",
              joint_kws={"marker": "^", "scatter_kws": {"color": colors}})

Upvotes: 4

Claire

Reputation: 719

  1. In seaborn/categorical.py, find def swarmplot.
  2. Add parameter marker='o' before **kwargs
  3. In kwargs.update, add marker=marker.

Then add e.g. marker='x' as a parameter when plotting with sns.swarmplot() as you would with Matplotlib plt.scatter().

I just came across the same need, and passing marker as a keyword argument did not work, so I had a brief look at the source. Other parameters can be set in a similar way. https://github.com/ccneko/seaborn/blob/master/seaborn/categorical.py

Only a small change is needed here, but here is my forked GitHub page for quick reference ;)

Upvotes: 2

Vlamir

Reputation: 421

Another option is to use JointGrid directly; jointplot is just a convenience wrapper around it.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

tips = sns.load_dataset("tips")

g = sns.JointGrid(x="total_bill", y="tip", data=tips)
g = g.plot_joint(plt.scatter, c=np.random.random((len(tips), 3)))
# sns.distplot is deprecated; histplot with kde=True is its closest replacement
g = g.plot_marginals(sns.histplot, kde=True, color="k")

Upvotes: 1

riri

Reputation: 509

The other two answers are more complex (to be fair, they're by people who truly understand what's going on under the hood).

Here's an answer from someone who's just guessing. It works, though!

import seaborn as sns

tips = sns.load_dataset("tips")
# In seaborn 0.11+, color points by a column with hue; palette plays the role of cmap
g = sns.jointplot(x="total_bill", y="tip", data=tips,
                  hue="day", palette="Set1",
                  xlim=(0, 60), ylim=(0, 12))

Upvotes: -2

Max Shron

Reputation: 966

The accepted answer is too complicated. plt.sca() can be used to do this in a simpler way:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

tips = sns.load_dataset("tips")
g = sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg",
                  xlim=(0, 60), ylim=(0, 12))

g.ax_joint.cla()  # or g.ax_joint.collections[0].set_visible(False), as per mwaskom's comment

# set the current axis to be the joint plot's axis
plt.sca(g.ax_joint)

# plt.scatter takes a 'c' keyword for color;
# you can also pass an array of floats and use the 'cmap' keyword
# to map them through a colormap
plt.scatter(tips.total_bill, tips.tip, c=np.random.random((len(tips), 3)))

Upvotes: 17

pbreach

Reputation: 17017

Solving this problem is almost no different from solving it in plain matplotlib (plotting a scatter plot with different markers and colors), except that I wanted to keep the marginal distributions:

import numpy as np
import seaborn as sns
sns.set(style="darkgrid")

tips = sns.load_dataset("tips")
g = sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg",
                  xlim=(0, 60), ylim=(0, 12), color='k', height=7)

# Clear the axes containing the scatter plot
g.ax_joint.cla()

# Generate some colors and markers (at least one of each per row)
colors = np.random.random((len(tips), 3))
markers = ['x', 'o', 'v', '^', '<'] * 100

# Plot each individual point separately
for i, row in enumerate(tips.values):
    g.ax_joint.plot(row[0], row[1], color=colors[i], marker=markers[i])

g.set_axis_labels('total bill', 'tip', fontsize=16)

Which gives me this:

[image: resulting joint plot]

The regression line is now gone, but this is all I needed.

Upvotes: 26
