rpb

Reputation: 3299

How to change the legend text when plotting a 3D scatter plot with Matplotlib?

I have a 3D scatter plot that was produced with the following code:

import seaborn as sns
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
# Create an example dataframe
data = {'th': [1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2],
        'pdvalue': [0.5, 0.5, 0.5, 0.5, 0.2,0.2,0.2,0.2,0.3,0.3,0.4,0.1,1,1.1,3,1],
        'my_val': [1.2,3.2,4,5.1,1,2,5.1,1,2,4,1,3,6,6,2,3],
        'name':['a','b','c','d','a','b','c','d','a','b','c','d','a','b','c','d']}
df = pd.DataFrame(data)
# convert each unique name into a unique int (a -> 0, b -> 1, ...)
order_dict = {k: i for i, k in enumerate(df['name'].unique())}
df['name_int'] = df['name'].map(order_dict)
data_np = df.to_numpy()

# extract the x, y, z columns

x = data_np[:,0]
y = data_np[:,1]
z = data_np[:,2]

# axes instance
fig = plt.figure(figsize=(10,6))
ax = fig.add_subplot(projection='3d')  # Axes3D(fig) no longer adds itself to the figure in newer Matplotlib

# get colormap from seaborn
cmap = ListedColormap(sns.color_palette("husl", 256).as_hex())

# plot
sc = ax.scatter(x, y, z, s=40, c=data_np[:,4], marker='o', cmap=cmap, alpha=1)
ax.set_xlabel('th')
ax.set_ylabel('pdvalue')
ax.set_zlabel('my_val')

# legend
plt.legend(*sc.legend_elements(), bbox_to_anchor=(1.05, 1), loc=2)
plt.show()

This produces the following plot:

[Image: 3D scatter plot whose legend shows the numeric codes instead of the names]

In the code above, I had to convert the names to integers because the c parameter of ax.scatter expects numbers (or colors), not arbitrary strings such as 'a'. As a result, the legend is labeled with the numeric values instead of the original names.
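For the string-to-integer step on its own, pandas provides pd.factorize, which returns one integer code per row together with the distinct values in order of first appearance. A minimal sketch, assuming only the 'name' column matters here; the per_row dictionary at the end is included only to show what happens when every row (rather than every unique name) is enumerated:

```python
import pandas as pd

# Same 'name' column as in the example dataframe (a, b, c, d repeated four times)
df = pd.DataFrame({'name': ['a', 'b', 'c', 'd'] * 4})

# codes: one integer per row; uniques: the distinct names in first-appearance order
codes, uniques = pd.factorize(df['name'])
df['name_int'] = codes
print(df['name_int'].tolist())  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
print(list(uniques))            # ['a', 'b', 'c', 'd']

# Enumerating every row overwrites each key with the index of its last occurrence
per_row = {k: i for i, k in enumerate(df['name'])}
print(per_row)                  # {'a': 12, 'b': 13, 'c': 14, 'd': 15}
```

Because uniques[i] is the name that received code i, the pair can be reused later to translate codes back into names, for example when labeling a legend.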

How can I make the legend show the names instead of their numeric representation?

Upvotes: 0

Views: 958

Answers (1)

JohanC

Reputation: 80509

The code can be simplified by letting pandas do the conversions and selections. If you draw a separate scatter plot for each 'name', each one can be given its own label for the legend.
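The idea of one scatter call per name can also be sketched with DataFrame.groupby, which yields each name together with its rows, so no explicit boolean mask is needed. This is a minimal variant, not the answer's exact code: it uses Matplotlib's default color cycle instead of the seaborn palette.

```python
import matplotlib.pyplot as plt
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers the '3d' projection on old Matplotlib

data = {'th': [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2],
        'pdvalue': [0.5, 0.5, 0.5, 0.5, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.4, 0.1, 1, 1.1, 3, 1],
        'my_val': [1.2, 3.2, 4, 5.1, 1, 2, 5.1, 1, 2, 4, 1, 3, 6, 6, 2, 3],
        'name': ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd']}
df = pd.DataFrame(data)

fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(projection='3d')

# groupby yields (name, sub-DataFrame) pairs, sorted by name by default
for name, group in df.groupby('name'):
    # label= on each call becomes that name's legend entry
    ax.scatter(group['th'], group['pdvalue'], group['my_val'], s=40, label=name)

ax.set_xlabel('th')
ax.set_ylabel('pdvalue')
ax.set_zlabel('my_val')
leg = ax.legend(bbox_to_anchor=(1.05, 1), loc=2)
plt.show()
```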

Here is the adapted code:

import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Create an example dataframe
data = {'th': [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2],
        'pdvalue': [0.5, 0.5, 0.5, 0.5, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.4, 0.1, 1, 1.1, 3, 1],
        'my_val': [1.2, 3.2, 4, 5.1, 1, 2, 5.1, 1, 2, 4, 1, 3, 6, 6, 2, 3],
        'name': ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd']}
df = pd.DataFrame(data)

# axes instance
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(projection='3d')  # creates the 3D axes and adds it to the figure

# find all the unique labels in the 'name' column
labels = np.unique(df['name'])
# get palette from seaborn
palette = sns.color_palette("husl", len(labels))

# plot
for label, color in zip(labels, palette):
    df1 = df[df['name'] == label]
    ax.scatter(df1['th'], df1['pdvalue'], df1['my_val'],
               s=40, marker='o', color=color, alpha=1, label=label)
ax.set_xlabel('th')
ax.set_ylabel('pdvalue')
ax.set_zlabel('my_val')

# legend
plt.legend(bbox_to_anchor=(1.05, 1), loc=2)
plt.show()

[Image: 3D scatter plot with a legend listing the names]
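If you would rather keep a single scatter call, as in the question, another option is to keep numeric c values and pass your own labels to the legend. With num=None, PathCollection.legend_elements returns one handle per unique value of c in ascending order, so when the codes are 0..n-1 from pd.factorize, the handles line up with the factorized names. This sketch uses Matplotlib's 'viridis' colormap rather than the seaborn palette, and it assumes the 'name' column has no missing values (factorize codes those as -1, which would shift the alignment).

```python
import matplotlib.pyplot as plt
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers the '3d' projection on old Matplotlib

data = {'th': [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2],
        'pdvalue': [0.5, 0.5, 0.5, 0.5, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.4, 0.1, 1, 1.1, 3, 1],
        'my_val': [1.2, 3.2, 4, 5.1, 1, 2, 5.1, 1, 2, 4, 1, 3, 6, 6, 2, 3],
        'name': ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd']}
df = pd.DataFrame(data)

# Integer code per row (0..n-1) plus the names, where names[i] is the name with code i
codes, names = pd.factorize(df['name'])

fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(projection='3d')
sc = ax.scatter(df['th'], df['pdvalue'], df['my_val'], s=40, c=codes, cmap='viridis')
ax.set_xlabel('th')
ax.set_ylabel('pdvalue')
ax.set_zlabel('my_val')

# num=None: one handle per unique value of c, in ascending order (0, 1, 2, 3 here)
handles, _ = sc.legend_elements(num=None)
# Discard the numeric labels and use the names, which share the same order as the codes
leg = ax.legend(handles, list(names), bbox_to_anchor=(1.05, 1), loc=2)
plt.show()
```

The trade-off compared with one scatter call per name is that colors come from a continuous colormap, so with many categories neighboring colors can be hard to tell apart.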

Upvotes: 1
