shakedzy

Reputation: 2893

Matplotlib: how to plot clusters with different colors and annotations?

Matplotlib is highly confusing to me. I have a pd.DataFrame with columns x, y and c (the cluster label). I want to plot this data on an x-y plot where every cluster has a different color and an annotation saying which cluster it is.

I'm capable of doing these separately. To plot the data with different colors:

import numpy as np
import matplotlib.pyplot as plt

# One plot call per cluster; each call gets the next color in the cycle
for c in np.unique(data['c']):
    df = data[data['c'] == c]
    plt.plot(df['x'], df['y'], 'o')
plt.show()
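The loop above produces different colors because every plot call on the same Axes takes the next color from Matplotlib's default property cycle. A minimal sketch of the same idea, swapping the np.unique filter for pandas groupby (my substitution, not the original code) and using a few rows of the sample data rounded for brevity:

```python
import matplotlib.pyplot as plt
import pandas as pd

# A few rows of the question's sample data, rounded for brevity
data = pd.DataFrame({
    'c': ['News', 'Internal Use', 'Business', 'Internal Use', 'News'],
    'x': [-95.44, 184.56, -191.46, 95.23, -151.88],
    'y': [-40.87, -458.30, 270.45, -453.30, -281.33],
})

fig, ax = plt.subplots()
# One plot() call per cluster; each call takes the next color from the cycle
for name, group in data.groupby('c'):
    ax.plot(group['x'], group['y'], 'o', label=name)

# Collect the color assigned to each cluster's line
colors = [line.get_color() for line in ax.get_lines()]
```

Because each line carries a label, calling ax.legend() afterwards would also list the cluster names; call plt.show() to display the figure.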

This yields:

[Image: scatter plot of the points, each cluster drawn in a different color]

And to add annotations:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x = data['x'].tolist()
y = data['y'].tolist()
ax.scatter(x, y)
# Label every point with its cluster name
for i, txt in enumerate(data['c'].tolist()):
    ax.annotate(txt, (x[i], y[i]))
plt.show()

This yields:

[Image: single-color scatter plot with each point labeled by its cluster name]

How do I combine the two? I don't understand how to mix the figure, axes and pyplot (plt.plot) APIs.


Sample data:

import pandas as pd

data = pd.DataFrame({'c': ['News',   'Hobbies & Interests',   'Arts & Entertainment',   'Internal Use',   'Business',   'Internal Use',   'Internal Use',   'Ad Impression Fraud',   'Arts & Entertainment',   'Adult Content',   'Arts & Entertainment',   'Internal Use',   'Internal Use',   'Reference',   'News',   'Shopping',   'Food & Drink',   'Internal Use',   'Internal Use',   'Reference'],
'x': [-95.44078826904297,   127.71454620361328,   -491.93121337890625,   184.5579071044922,   -191.46273803710938,   95.22545623779297,   272.2229919433594,   -67.099365234375,   -317.60797119140625,   -175.90196228027344,   -491.93121337890625,   214.3858642578125,   184.5579071044922,   346.4012756347656,   -151.8809051513672,   431.6130676269531,   -299.4017028808594,   184.5579071044922,   184.5579071044922,   241.29026794433594],  
'y': [-40.87070846557617,   245.00514221191406,   43.07831954956055,   -458.2991638183594,   270.4497985839844,   -453.2981262207031,   -439.6551513671875,   -206.3104248046875,   205.25787353515625,   -58.520164489746094,   43.07831954956055,   -182.91664123535156,   -458.2991638183594,   19.559282302856445,   -281.3316650390625,   103.6922378540039,   280.2445373535156,   -458.2991638183594,   -458.2991638183594,   -113.96920776367188]})

Upvotes: 1

Views: 5600

Answers (2)

shakedzy

Reputation: 2893

Surprisingly, combining the two methods also solved it:

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
fig.set_size_inches(20, 20)
# Label every point with its cluster name
for txt, x, y in zip(data['c'], data['x'], data['y']):
    ax.annotate(txt, (x, y))
# Draw each cluster separately so each one gets its own color
for c in np.unique(data['c']):
    df = data[data['c'] == c]
    ax.plot(df['x'], df['y'], 'o')
plt.show()
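Mixing the two styles works because pyplot functions such as plt.plot draw on the "current" Axes, and plt.subplots() makes the Axes it returns the current one. So plt.plot(...) and ax.plot(...) end up on the same axes. A small check with toy coordinates (purely illustrative):

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# plt.subplots() makes the newly created Axes the current one
same_axes = plt.gca() is ax

# Both calls add a Line2D to the same Axes object
plt.plot([0, 1], [0, 1], 'o')   # pyplot style: targets the current Axes
ax.plot([0, 1], [1, 0], 'o')    # object-oriented style: targets ax explicitly
n_lines = len(ax.get_lines())
```

Calling methods on ax directly avoids depending on which Axes happens to be current, which matters once a figure has several subplots.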

Upvotes: 0

Nico Albers

Reputation: 1696

I'll use the df.plot.scatter syntax for convenience, but it should work (nearly) the same way with ax.scatter.

Using your example data, you can specify a colormap (cmap) as described in the docs:

import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame({'c': ['News',   'Hobbies & Interests',   'Arts & Entertainment',   'Internal Use',   'Business',   'Internal Use',   'Internal Use',   'Ad Impression Fraud',   'Arts & Entertainment',   'Adult Content',   'Arts & Entertainment',   'Internal Use',   'Internal Use',   'Reference',   'News',   'Shopping',   'Food & Drink',   'Internal Use',   'Internal Use',   'Reference'],  
'x': [-95.44078826904297,   127.71454620361328,   -491.93121337890625,   184.5579071044922,   -191.46273803710938,   95.22545623779297,   272.2229919433594,   -67.099365234375,   -317.60797119140625,   -175.90196228027344,   -491.93121337890625,   214.3858642578125,   184.5579071044922,   346.4012756347656,   -151.8809051513672,   431.6130676269531,   -299.4017028808594,   184.5579071044922,   184.5579071044922,   241.29026794433594],  
'y': [-40.87070846557617,   245.00514221191406,   43.07831954956055,   -458.2991638183594,   270.4497985839844,   -453.2981262207031,   -439.6551513671875,   -206.3104248046875,   205.25787353515625,   -58.520164489746094,   43.07831954956055,   -182.91664123535156,   -458.2991638183594,   19.559282302856445,   -281.3316650390625,   103.6922378540039,   280.2445373535156,   -458.2991638183594,   -458.2991638183594,   -113.96920776367188]})

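# Turn each category name into an integer code (0, 1, 2, ...)
df['col'] = df.c.astype('category').cat.codes

# Colormap with one entry per category (plt.cm.get_cmap was removed in Matplotlib 3.9)
cmap = plt.get_cmap('jet', df.c.nunique())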
ax = df.plot.scatter(
    x='x',y='y', c='col',
    cmap=cmap
)
plt.show()

Here get_cmap takes a colormap name (you can find the names of the available colormaps on this example page) and "an integer giving the number of entries desired in the lookup table".
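To see how the category codes and the resampled colormap fit together, here is a small sketch on a hypothetical handful of labels taken from the sample categories. It uses plt.get_cmap, the pyplot counterpart of plt.cm.get_cmap, and relies on the fact that calling a colormap with an integer indexes its lookup table directly:

```python
import matplotlib.pyplot as plt
import pandas as pd

labels = pd.Series(['News', 'Business', 'News', 'Reference', 'Business'])

# Categories are sorted, so the codes are: Business -> 0, News -> 1, Reference -> 2
codes = labels.astype('category').cat.codes

# A colormap with exactly as many entries as there are categories
cmap = plt.get_cmap('jet', labels.nunique())

# An integer argument picks that entry of the lookup table: one RGBA color per code
point_colors = [cmap(int(code)) for code in codes]
```

With the lookup table sized to the number of categories, every category gets its own discrete color instead of neighboring shades from a continuous ramp.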

The above code results in the following: [Image: scatter plot colored by category, with a colorbar]

If you want to add your annotations and suppress the colorbar, use:

ax = df.plot.scatter(
    x='x',y='y', c='col',
    cmap=cmap, colorbar=False
)
for i, txt in enumerate(df['c'].tolist()):
    ax.annotate(txt, (df.x[i], df.y[i]))
plt.show()
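In the sample data several rows share coordinates (for example, four "Internal Use" points at about (184.56, -458.30)), so one label per point stacks identical text on top of itself. A variation, which is my suggestion rather than part of the answer, places a single label per cluster at the mean position of its points. It uses a few rows of the sample data, rounded:

```python
import matplotlib.pyplot as plt
import pandas as pd

# A few rows of the question's sample data, rounded; 'Internal Use' repeats a point
df = pd.DataFrame({
    'c': ['News', 'Internal Use', 'Internal Use', 'Business', 'News', 'Internal Use'],
    'x': [-95.44, 184.56, 184.56, -191.46, -151.88, 95.23],
    'y': [-40.87, -458.30, -458.30, 270.45, -281.33, -453.30],
})
df['col'] = df.c.astype('category').cat.codes
cmap = plt.get_cmap('jet', df.c.nunique())

ax = df.plot.scatter(x='x', y='y', c='col', cmap=cmap, colorbar=False)

# One label per cluster, placed at the mean (centroid) of its points
centroids = df.groupby('c')[['x', 'y']].mean()
for name, row in centroids.iterrows():
    ax.annotate(name, (row['x'], row['y']))
```

For widely spread clusters the centroid can fall in empty space, so per-point labels may still be preferable there.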

And get the following: [Image: colored scatter plot with each point labeled by its category, no colorbar]

Hint: use the s parameter of df.plot.scatter(x, y, s=None, c=None, **kwds) to change the marker size if the points are too small.
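Since df.plot.scatter is described as nearly the same as ax.scatter, here is a sketch of the equivalent call on an explicit Axes, including the s parameter (marker area in points squared; the default is rcParams['lines.markersize'] ** 2). It uses a few rows of the sample data, rounded, and s=80 is an arbitrary illustrative value:

```python
import matplotlib.pyplot as plt
import pandas as pd

# A few rows of the question's sample data, rounded for brevity
df = pd.DataFrame({
    'c': ['News', 'Business', 'Reference', 'News', 'Shopping'],
    'x': [-95.44, -191.46, 346.40, -151.88, 431.61],
    'y': [-40.87, 270.45, 19.56, -281.33, 103.69],
})
codes = df['c'].astype('category').cat.codes
cmap = plt.get_cmap('jet', df['c'].nunique())

fig, ax = plt.subplots(figsize=(8, 8))
# Color each point by its category code; s enlarges the markers
ax.scatter(df['x'], df['y'], c=codes, cmap=cmap, s=80)
# Label every point with its category name
for txt, x, y in zip(df['c'], df['x'], df['y']):
    ax.annotate(txt, (x, y))
```

No colorbar is drawn here because Axes.scatter never adds one on its own; call fig.colorbar(...) explicitly if you want it.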

Upvotes: 2
