Reputation: 21
I need to replicate this plot in Python. Specifically, I am unsure how to (1) change the marker shapes and (2) add class labels above each distribution, as shown in the example figure. Could you provide guidance on how to achieve these two features?
The desired image:
The currently obtained image:
The dataset: https://www.kaggle.com/datasets/uciml/iris
Currently, this is my code:
sns.jointplot(data=my_df, x="sepal length", y="sepal width", hue="class")
Upvotes: 1
Views: 36
Reputation: 2070
In jointplot, you can use the joint_kws parameter, which passes keyword arguments to the central ("joint") plot, here a scatter plot; that is how you set the markers. You also need to pass the relevant column to the "style" parameter.
The text has to be added by hand, by writing directly on the top (x) marginal axes.
g = sns.jointplot(data=data, x="sepal length", y="sepal width", hue="species",
                  style=data["species"], palette="tab10", alpha=0.5,
                  joint_kws={"markers": ('o', 's', 'D')},
                  )
g.ax_marg_x.text(5.2, .35, "iris setosa", color="C0")
g.ax_marg_x.text(5.8, .25, "iris versicolor", color="C1")
g.ax_marg_x.text(7.0, .15, "iris virginica", color="C2")
g.ax_joint.legend_.remove()
plt.show()
Upvotes: 0
Reputation: 80459
jointplot doesn't directly accept parameters for using different markers. The workaround is to clear the scatter plot created by jointplot and draw a new scatter plot with the extra parameters.
The positions of the texts in your example plot aren't placed automatically. Somebody calculated (or tried out) some specific x and y coordinates.
Here is what the code could look like, using seaborn's iris dataset.
from matplotlib import pyplot as plt
import seaborn as sns
iris = sns.load_dataset('iris')
g = sns.jointplot(data=iris, x="sepal_length", y="sepal_width", hue="species")
# clear the central subplot, and plot a scatterplot with extra parameters
g.ax_joint.cla()
sns.scatterplot(data=iris, x="sepal_length", y="sepal_width", hue="species",
                style="species", markers=['o', 's', 'D'], lw=1.5, edgecolor='face',
                alpha=0.5, ax=g.ax_joint)
# label each marginal distribution, reusing the legend's texts and colors
for handle, txt, (x, y) in zip(g.ax_joint.legend_.legend_handles, g.ax_joint.legend_.texts,
                               [(5.2, 0.34), (5.6, 0.27), (7.0, 0.14)]):
    g.ax_marg_x.text(x, y, txt.get_text(), color=handle.get_color())
g.ax_joint.legend_.remove() # remove the legend
plt.show()
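The label coordinates above were found by trial and error. As a rough alternative, here is a minimal sketch that derives the x positions from the data (the mean of each class) and gives y as a fraction of the marginal axes' height, so the labels keep their place if the data changes. The helper names label_x_positions and label_marginal and the default y_fracs values are made up for this illustration. The colors "C0", "C1", ... are only assumed to match the hue colors, which holds when the default palette is used and the classes appear in the same order as the hue levels.

```python
from statistics import mean


def label_x_positions(x_values, classes):
    """Return {class name: mean x value}, keeping first-seen class order."""
    groups = {}
    for x, c in zip(x_values, classes):
        groups.setdefault(c, []).append(x)
    return {c: mean(v) for c, v in groups.items()}


def label_marginal(g, x_values, classes, y_fracs=(0.9, 0.7, 0.5)):
    """Write each class name on the top marginal axes of a JointGrid `g`.

    x is given in data coordinates (the class mean); y is a fraction of the
    axes height taken from `y_fracs` (hypothetical defaults, tune by eye).
    """
    ax = g.ax_marg_x
    centers = label_x_positions(x_values, classes)
    for i, (name, x) in enumerate(centers.items()):
        ax.text(x, y_fracs[i % len(y_fracs)], name,
                color=f"C{i}",  # assumption: default color cycle matches hue
                ha="center",
                transform=ax.get_xaxis_transform())  # x: data, y: axes fraction
    return centers


# Usage with the plot above (requires seaborn and matplotlib):
#   label_marginal(g, iris["sepal_length"], iris["species"])
```

The y values still need some manual tuning, because the height of each density curve is not taken into account here.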
Upvotes: 1