Reputation: 81
I just upgraded matplotlib to version 3.1.1 and I am experimenting with using legend_elements.
I am making a scatterplot of the top two components from PCA on a dataset of 30,000 flattened, grayscale images. Each image is labeled as one of four master categories (Accessories, Apparel, Footwear, Personal Care). I have color-coded the plot by 'master category' by creating a colors column with values from 0 to 3.
I have read the documentation for PathCollection.legend_elements, but I haven't successfully incorporated the 'func' or 'fmt' parameters. https://matplotlib.org/3.1.1/api/collections_api.html#matplotlib.collections.PathCollection.legend_elements
Also, I have tried to follow examples provided: https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/scatter_with_legend.html
### create column for color codes
masterCat_codes = {'Accessories':0,'Apparel':1, 'Footwear':2, 'Personal Care':3}
df['colors'] = df['masterCategory'].apply(lambda x: masterCat_codes[x])
### create scatter plot
fig, ax = plt.subplots(figsize=(8,8))
scatter = ax.scatter( *full_pca.T, s=.1 , c=df['colors'], label= df['masterCategory'], cmap='viridis')
### using legend_elements
legend1 = ax.legend(*scatter.legend_elements(num=[0,1,2,3]), loc="upper left", title="Category Codes")
ax.add_artist(legend1)
plt.show()
The resulting legend labels are 0, 1, 2, 3. (This happens whether or not I specify label = df['masterCategory'] when defining 'scatter'). I would like labels to say Accessories, Apparel, Footwear, Personal Care.
Is there a way to accomplish this with legend_elements?
Note: As the dataset is large and the preprocessing is computationally heavy, I have written an example that is simpler to reproduce:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
fake_data = np.array([[1,1],[1,2],[1,3],[2,1],[2,2],[2,3],[3,1],[3,2],[3,3]])
fake_df = pd.DataFrame(fake_data, columns=['X', 'Y'])
groups = np.array(['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'])
fake_df['Group'] = groups
group_codes = {k:idx for idx, k in enumerate(fake_df.Group.unique())}
fake_df['colors'] = fake_df['Group'].apply(lambda x: group_codes[x])
fig, ax = plt.subplots()
scatter = ax.scatter(fake_data[:,0], fake_data[:,1], c=fake_df['colors'])
legend = ax.legend(*scatter.legend_elements(num=[0,1,2]), loc="upper left", title="Group \nCodes")
ax.add_artist(legend)
plt.show()
Upvotes: 4
Views: 8688
Reputation: 81
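Solution, thanks to ImportanceOfBeingErnest:
.legend_elements returns legend handles and labels for a PathCollection. The handles are the first object returned by the method, so they can be extracted by indexing the result. For the four-category plot, that is:
handles = scatter.legend_elements(num=[0,1,2,3])[0]
The handles can then be passed to ax.legend together with the category names as labels. Applied to the simplified example from the question:
group_codes = {k:idx for idx, k in enumerate(fake_df.Group.unique())}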
fake_df['colors'] = fake_df['Group'].apply(lambda x: group_codes[x])
fig, ax = plt.subplots(figsize=(8,8))
scatter = ax.scatter(fake_data[:,0], fake_data[:,1], c=fake_df['colors'])
handles = scatter.legend_elements(num=[0,1,2])[0] # extract the handles from the existing scatter plot (the example has three groups, coded 0 to 2)
ax.legend(title='Group\nCodes', handles=handles, labels=list(group_codes)) # dict keys are in the same order as the codes
plt.show()
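Since the question also asks about the 'fmt' parameter, here is a minimal sketch of an alternative that lets legend_elements produce the names itself. It assumes that 'fmt' accepts a matplotlib.ticker.FuncFormatter, as the documentation says it takes a str or a Formatter. The names code_to_name and code_label are hypothetical helpers introduced for illustration, and the data mirrors the simplified example from the question.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter

# Same 3x3 grid and group labels as the simplified example in the question
x = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3])
y = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
groups = np.array(["A", "A", "A", "B", "B", "B", "C", "C", "C"])

# Integer code per point, plus the sorted unique names the codes index into
names, codes = np.unique(groups, return_inverse=True)
code_to_name = dict(enumerate(names.tolist()))  # hypothetical helper: {0: 'A', 1: 'B', 2: 'C'}

def code_label(value, pos=None):
    """Hypothetical formatter callback: turn a numeric color code into its group name."""
    return code_to_name[int(round(value))]

fig, ax = plt.subplots()
scatter = ax.scatter(x, y, c=codes)

# Ask for one legend entry per code and format each entry with the callback
handles, labels = scatter.legend_elements(
    num=sorted(code_to_name),
    fmt=FuncFormatter(code_label),
)
ax.legend(handles, labels, loc="upper left", title="Group")
# Call plt.show() (with an interactive backend) or fig.savefig(...) to view the result
```

The difference from the approach above is only where the names come from: here the labels returned by legend_elements are already the group names, so handles and labels stay paired without indexing the result.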
Upvotes: 4