Reputation: 3299
I have a 3D scatter plot which was produced using the following code:
import seaborn as sns
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
import pandas as pd  # pd.DataFrame is used below, so pandas must be in scope

# Create an example dataframe
data = {
    'th': [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2],
    'pdvalue': [0.5, 0.5, 0.5, 0.5, 0.2, 0.2, 0.2, 0.2,
                0.3, 0.3, 0.4, 0.1, 1, 1.1, 3, 1],
    'my_val': [1.2, 3.2, 4, 5.1, 1, 2, 5.1, 1, 2, 4, 1, 3, 6, 6, 2, 3],
    'name': ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'd',
             'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd'],
}
df = pd.DataFrame(data)

# Convert each unique str into a unique int. Enumerating the *unique* names
# gives compact codes (a -> 0, b -> 1, ...) instead of the index of the last
# row in which each name happens to appear.
order_dict = {name: code for code, name in enumerate(df['name'].unique())}
df['name_int'] = df['name'].map(order_dict)

# Extract each column as its own array so the numeric dtypes are kept;
# df.to_numpy() on the whole frame would yield a single object-dtype array
# because the 'name' column holds strings.
x = df['th'].to_numpy()
y = df['pdvalue'].to_numpy()
z = df['my_val'].to_numpy()
codes = df['name_int'].to_numpy()

# Axes instance. add_subplot(projection='3d') both creates the 3D axes and
# attaches them to the figure.
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(projection='3d')

# Get colormap from seaborn
cmap = ListedColormap(sns.color_palette("husl", 256).as_hex())

# Plot: the integer codes drive the colour of each point
sc = ax.scatter(x, y, z, s=40, c=codes, marker='o', cmap=cmap, alpha=1)
ax.set_xlabel('th')
ax.set_ylabel('pdvalue')
ax.set_zlabel('my_val')

# Legend: legend_elements() derives its entries from the numeric values passed
# as `c`, so the legend shows the integer codes rather than the original names.
ax.legend(*sc.legend_elements(), bbox_to_anchor=(1.05, 1), loc=2)
plt.show()
and this produces the following plot:
In the above, I had to convert the `name` column into an integer type, as the parameter `c` of `ax.scatter` only accepts numbers. As a result, the legend was mapped according to the numeric value instead of the original `name`.
May I know how to have the legend in terms of `name` instead of the numerical representation?
Upvotes: 0
Views: 958
Reputation: 80509
The code can be simplified by making use of pandas to do the conversions and selections. By drawing the scatter plot for each 'name' separately, each one can be given its own label for the legend.
Here is the adapted code:
import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Create an example dataframe
data = {
    'th': [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2],
    'pdvalue': [0.5, 0.5, 0.5, 0.5, 0.2, 0.2, 0.2, 0.2,
                0.3, 0.3, 0.4, 0.1, 1, 1.1, 3, 1],
    'my_val': [1.2, 3.2, 4, 5.1, 1, 2, 5.1, 1, 2, 4, 1, 3, 6, 6, 2, 3],
    'name': ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'd',
             'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd'],
}
df = pd.DataFrame(data)

# Axes instance. add_subplot(projection='3d') both creates the 3D axes and
# attaches them to the figure.
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(projection='3d')

# Split the rows by 'name'; groupby yields the groups in sorted key order
groups = df.groupby('name')

# Get one palette colour per distinct name from seaborn
palette = sns.color_palette("husl", groups.ngroups)

# Plot each name as its own scatter collection so that it can carry its own
# legend label; no string -> int conversion is needed.
for (label, group), color in zip(groups, palette):
    ax.scatter(group['th'], group['pdvalue'], group['my_val'],
               s=40, marker='o', color=color, alpha=1, label=label)

ax.set_xlabel('th')
ax.set_ylabel('pdvalue')
ax.set_zlabel('my_val')

# Legend: entries come from the `label` given to each scatter call
ax.legend(bbox_to_anchor=(1.05, 1), loc=2)
plt.show()
Upvotes: 1