Pulpito

Reputation: 23

Redundant Legends: Matplotlib

My scatter plot has redundant legend entries. Here is an image of my plot:

[image of the scatter plot]

I already checked the following existing Stack Overflow question: "too many legend with array column data in matplotlib".

Unfortunately, it did not help, so I suspect I am running into a different issue. How can I resolve this?

Here is my code:

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np

# N_Clus and WorkingDF2 are defined earlier (not shown).
colors = cm.rainbow(np.linspace(0, 1, N_Clus))
cluster_labels_2 = list(range(1, N_Clus+1))
print("cluster_labels: ", cluster_labels_2)
# Create a figure
plt.figure(figsize=(15, 8))
s = 0
for color, label in zip(colors, np.asarray(cluster_labels_2).flatten()):
    subset = WorkingDF2[WorkingDF2.Cluster == label]
    for i in subset.index:
        x = np.asarray(subset["Standardized COVID-19 Index"][i]).flatten()
        y = np.asarray(subset["Standardized CSS Index"][i]).flatten()
        plt.text(x, y, str(subset['Neighbourhood'][i]), rotation=25)
        s += 1
        plt.scatter(x, y, c=np.array([color]), label='cluster'+str(label), alpha=0.5)
plt.legend(loc='lower right', fontsize=15)
plt.xlabel('Standardized COVID-19 Index', fontsize=18)
plt.ylabel('Standardized CSS Index', fontsize=18)
# Adjacent string literals are joined before .format() is applied.
plt.title("[Hierarchical Clustering: {} Cluster]\n"
          "Mapping of Non-Outlier Neighbourhoods\n"
          "onto Standardized CSS-COVID19 Indices Space\n".format(N_Clus), fontsize=18)
print('# of Neighbours: ', s)

Upvotes: 0

Views: 355

Answers (1)

Guimoute

Reputation: 4649

The issue comes from the line

plt.scatter(x, y, c=np.array([color]), label='cluster'+str(label),alpha=0.5)

Here, every dot gets the label 'cluster' + str(label), even when a dot with that same label has already been plotted, so plt.legend() creates one identical legend entry per point instead of one per cluster. I would keep track of the clusters that already have a legend entry and set the label of every later point in the same cluster to None, so that plt.legend() ignores it.
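A minimal, self-contained sketch of both the cause and the fix, using made-up random points instead of WorkingDF2 (the helper name legend_labels and the data are hypothetical). Like the question's code, it draws one point per scatter call and then lists the labels that end up in the legend, with and without the bookkeeping:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def legend_labels(dedupe):
    """Plot three points per cluster, one scatter call per point (as in the
    question), and return the texts shown in the resulting legend."""
    rng = np.random.default_rng(0)
    fig, ax = plt.subplots()
    seen = set()
    for cluster in (1, 2):
        for _ in range(3):
            x, y = rng.random(2)
            if dedupe and cluster in seen:
                label = None  # no user label, so legend() skips this point
            else:
                seen.add(cluster)
                label = f"Cluster {cluster}"
            ax.scatter(x, y, label=label)
    legend = ax.legend()
    labels = [t.get_text() for t in legend.get_texts()]
    plt.close(fig)
    return labels


print(legend_labels(dedupe=False))  # one entry per point: 'Cluster 1' x3, 'Cluster 2' x3
print(legend_labels(dedupe=True))   # one entry per cluster
```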

Notice that your naming choice can be a little confusing, as matplotlib uses "label" for the text that identifies a plotted element in the legend, while you use it for a cluster number. Can we just call it cluster_number instead?

Here is the implementation:

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np

colors = cm.rainbow(np.linspace(0, 1, N_Clus))
cluster_labels_2 = list(range(1, N_Clus+1))
print("cluster_labels: ", cluster_labels_2)

# Create a figure.
plt.figure(figsize=(15, 8))
s = 0
clusters_already_in_the_legend = []
for color, cluster_number in zip(colors, np.asarray(cluster_labels_2).flatten()):
    subset = WorkingDF2[WorkingDF2.Cluster == cluster_number]
    for i in subset.index:
        x = np.asarray(subset["Standardized COVID-19 Index"][i]).flatten()
        y = np.asarray(subset["Standardized CSS Index"][i]).flatten()
        plt.text(x, y, str(subset['Neighbourhood'][i]), rotation=25)
        s += 1

        # Keep track of the clusters already in the legend so that each one is listed only once.
        if cluster_number not in clusters_already_in_the_legend:
            clusters_already_in_the_legend.append(cluster_number)
            label = f"Cluster {cluster_number}"
        else:
            label = None
        plt.scatter(x, y, c=np.array([color]), label=label, alpha=0.5)

plt.legend(loc='lower right', fontsize=15)
plt.xlabel('Standardized COVID-19 Index', fontsize=18)
plt.ylabel('Standardized CSS Index', fontsize=18)
# Adjacent string literals are joined before .format() is applied.
plt.title("[Hierarchical Clustering: {} Cluster]\n"
          "Mapping of Non-Outlier Neighbourhoods\n"
          "onto Standardized CSS-COVID19 Indices Space\n".format(N_Clus), fontsize=18)
print('# of Neighbours: ', s)
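Two alternatives, sketched under assumptions: the helper names dedupe_legend and plot_clusters and the toy DataFrame are hypothetical, while the column names are copied from the question. Option 1 keeps the per-point loop and drops repeated labels only when building the legend, using ax.get_legend_handles_labels(). Option 2 draws each cluster with a single scatter call, which yields exactly one legend entry per cluster and creates far fewer artists than one call per point.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

X_COL, Y_COL = "Standardized COVID-19 Index", "Standardized CSS Index"


def dedupe_legend(ax, **legend_kwargs):
    """Option 1 (hypothetical helper): keep the first handle for each label."""
    handles, labels = ax.get_legend_handles_labels()
    unique = {}  # dicts preserve insertion order
    for handle, label in zip(handles, labels):
        unique.setdefault(label, handle)
    return ax.legend(list(unique.values()), list(unique.keys()), **legend_kwargs)


def plot_clusters(df, n_clus, ax):
    """Option 2 (hypothetical helper): one scatter call per cluster,
    assuming clusters are numbered 1..n_clus as in the question."""
    colors = cm.rainbow(np.linspace(0, 1, n_clus))
    for color, cluster in zip(colors, range(1, n_clus + 1)):
        subset = df[df["Cluster"] == cluster]
        ax.scatter(subset[X_COL], subset[Y_COL], color=color, alpha=0.5,
                   label=f"Cluster {cluster}")
        for _, row in subset.iterrows():  # annotate each neighbourhood
            ax.text(row[X_COL], row[Y_COL], str(row["Neighbourhood"]), rotation=25)
    return ax.legend(loc="lower right", fontsize=15)


# Hypothetical toy data with the question's column names.
toy = pd.DataFrame({
    "Neighbourhood": ["A", "B", "C", "D", "E"],
    X_COL: [0.1, 0.4, -0.3, 1.2, 0.8],
    Y_COL: [0.5, -0.2, 0.9, 0.3, -0.7],
    "Cluster": [1, 1, 2, 2, 2],
})

fig, ax = plt.subplots()
labels_grouped = [t.get_text() for t in plot_clusters(toy, 2, ax).get_texts()]
plt.close(fig)

fig, ax = plt.subplots()
for _, row in toy.iterrows():  # the question's one-call-per-point pattern
    ax.scatter(row[X_COL], row[Y_COL], label=f"Cluster {row['Cluster']}")
labels_deduped = [t.get_text() for t in dedupe_legend(ax, loc="lower right").get_texts()]
plt.close(fig)

print(labels_grouped)  # ['Cluster 1', 'Cluster 2']
print(labels_deduped)  # ['Cluster 1', 'Cluster 2']
```

Option 2 keeps the answer's color assignment (cluster k gets the k-th rainbow color); option 1 is the smaller change if the per-point loop is otherwise needed.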

Upvotes: 1

Related Questions