modulitos

Reputation: 15814

Python: How to create a legend using an example

This is from Chapter 2 of the book Machine Learning in Action, and I am trying to make the plot pictured here:

[image: the target scatter plot from the book]

The author has posted the code for this plot here, but I suspect it is a bit hacky (he also mentions that the code is sloppy because it is outside the book's scope).

Here is my attempt to re-create the plot:

First, the .txt file holding the data looks like this (source: "datingTestSet2.txt" in Ch. 2 here):

40920   8.326976    0.953952    largeDoses
14488   7.153469    1.673904    smallDoses
26052   1.441871    0.805124    didntLike
75136   13.147394   0.428964    didntLike
38344   1.669788    0.134296    didntLike
...

Assume datingDataMat is a numpy.ndarray of shape (1000, 3), where column 0 is "Frequent Flier Miles Per Year", column 1 is "% Time Playing Video Games", and column 2 is "liters of ice cream consumed per week", as shown in the sample above.

Assume datingLabels is a list of the ints 1, 2, or 3, meaning "Did Not Like", "Liked in Small Doses", and "Liked in Large Doses" respectively; these correspond to column 3 (the last column) above.
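The sample rows show text labels in the last column, while the description above (and file2matrix below, which calls int() on that column) expects the integers 1, 2, 3. If your copy of the file has the text labels, a small lookup could convert them. This is only a sketch: LABEL_CODES and parse_label are hypothetical names, the mapping assumes "didntLike", "smallDoses" and "largeDoses" correspond to 1, 2 and 3 as described above, and the fields are assumed to be tab-separated, as file2matrix assumes.

```python
# Hypothetical lookup: text labels from the sample rows mapped to the
# integer codes 1, 2, 3 described above (assumed correspondence).
LABEL_CODES = {"didntLike": 1, "smallDoses": 2, "largeDoses": 3}


def parse_label(token):
    """Return the integer class label for one label field.

    Accepts either an integer string ("1", "2", "3") or one of the
    text labels shown in the sample ("didntLike", "smallDoses", "largeDoses").
    """
    token = token.strip()
    if token.isdigit():
        return int(token)
    return LABEL_CODES[token]


# First sample row, assuming tab-separated fields
sample = "40920\t8.326976\t0.953952\tlargeDoses"
fields = sample.split("\t")
features = [float(x) for x in fields[:3]]
label = parse_label(fields[-1])
print(features, label)  # [40920.0, 8.326976, 0.953952] 3
```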

Here is the code I have to create the plot (full details for file2matrix are at the end):

import numpy as np
import matplotlib.pyplot as plt

datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")

fig = plt.figure()
ax = fig.add_subplot(111)
plt.xlabel("Freq flier miles")
plt.ylabel("% time video games")
# Not sure how to finish this: plt.legend([1, 2, 3], ["did not like", "small doses", "large doses"])
# Change marker size (s) and color (c): both scale with the class label
plt.scatter(datingDataMat[:, 0], datingDataMat[:, 1],
            s=15.0*np.array(datingLabels), c=15.0*np.array(datingLabels))
plt.show()

The output is here:

[image: the resulting scatter plot]

My main concern is how to create this legend. Is there a way to do this without needing a direct handle to the points?
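One way to build a legend without any handle to the plotted points is to use "proxy artists": empty Line2D objects that carry only a marker style and a label, passed to legend(handles=...). The sketch below keeps the question's single scatter call and looks up each class's color from the scatter's own colormap and normalization. The random data, the class names and the sqrt size conversion are assumptions for illustration only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np

# Synthetic stand-ins for datingDataMat / datingLabels (assumption, not the real data)
rng = np.random.RandomState(0)
data = rng.rand(1000, 2)
labels = rng.randint(1, 4, size=1000)

fig, ax = plt.subplots()
# A single scatter call, as in the question: size and color both follow the label
sc = ax.scatter(data[:, 0], data[:, 1], s=15.0 * labels, c=15.0 * labels)

names = ["did not like", "small doses", "large doses"]
proxies = []
for k, name in zip([1, 2, 3], names):
    # Color the scatter's colormap assigned to this label value
    color = sc.cmap(sc.norm(15.0 * k))
    # scatter's s is in points**2, Line2D's markersize is in points
    proxies.append(Line2D([], [], linestyle="", marker="o",
                          markerfacecolor=color, markeredgecolor=color,
                          markersize=np.sqrt(15.0 * k), label=name))

ax.legend(handles=proxies, loc="upper left")
ax.set_xlabel("Freq flier miles")
ax.set_ylabel("% time video games")
```

In an interactive session, drop the Agg line and finish with plt.show().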

Next, I am curious whether I can find a way to switch the colors to match those of the plot. Is there a way to do this without having some kind of "handle" on the individual points?
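A possible approach for specific colors, again without touching individual points: build one color per point from the labels up front and pass that list as c in the same single scatter call. The red/green/blue choice below is only a placeholder (substitute the colors of the book's plot), and the data is synthetic.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-ins for datingDataMat / datingLabels (assumption)
rng = np.random.RandomState(1)
data = rng.rand(1000, 2)
labels = rng.randint(1, 4, size=1000)

# Placeholder colors, one per class, indexed by label - 1
class_colors = ["red", "green", "blue"]
point_colors = [class_colors[k - 1] for k in labels]  # one color name per point

fig, ax = plt.subplots()
# Still a single scatter call: each point takes its color from point_colors
sc = ax.scatter(data[:, 0], data[:, 1], s=15.0 * labels, c=point_colors)
```

Because the colors are chosen explicitly, a legend built from the same class_colors list will match the points.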

Also, in case it is of interest, here is the file2matrix implementation:

import numpy as np

def file2matrix(filename):
    # Read all lines once; the with block closes the file automatically
    with open(filename) as fr:
        lines = fr.readlines()
    numberOfLines = len(lines)
    returnMat = np.zeros((numberOfLines, 3))  # numpy.zeros(shape, dtype=float, order='C')
    classLabelVector = []
    for index, line in enumerate(lines):
        listFromLine = line.strip().split('\t')
        returnMat[index, :] = listFromLine[0:3]  # FFmiles/yr, % time gaming, L ice cream/wk
        classLabelVector.append(int(listFromLine[-1]))  # class label: 1, 2 or 3
    return returnMat, classLabelVector
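For comparison, the same parsing can be done with numpy.loadtxt. This is a swapped-in technique, not the book's code: load_dating_data is a hypothetical name, the file is assumed to be tab-separated with an integer label in the fourth column, and the labels come back as a numpy array rather than a list. Reading everything as strings once and converting the columns afterwards means the source is read only a single time, so a file-like object works as well as a filename.

```python
from io import StringIO

import numpy as np


def load_dating_data(source):
    """Hypothetical numpy.loadtxt alternative to file2matrix.

    Assumes tab-separated rows: three numeric features, then an integer
    label (1, 2 or 3). Returns (features, labels) as numpy arrays.
    """
    # ndmin=2 keeps a 2-D result even for a one-line file
    raw = np.loadtxt(source, delimiter="\t", dtype=str, ndmin=2)
    features = raw[:, :3].astype(float)
    labels = raw[:, 3].astype(int)
    return features, labels


# The first two sample rows, with labels written as integers (assumed 3 and 2)
sample = StringIO("40920\t8.326976\t0.953952\t3\n"
                  "14488\t7.153469\t1.673904\t2\n")
X, y = load_dating_data(sample)
print(X.shape, y.tolist())  # (2, 3) [3, 2]
```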

Upvotes: 3

Views: 2465

Answers (2)

IanH

Reputation: 10690

Here's an example that mimics your code and uses the approach described in Saullo Castro's answer. It also shows how to set the colors. For more information on the available colors, see the documentation at http://matplotlib.org/api/colors_api.html

It would also be worth looking at the scatter plot documentation at http://matplotlib.org/1.3.1/api/pyplot_api.html#matplotlib.pyplot.scatter

from numpy.random import rand, randint
from matplotlib import pyplot as plt
n = 1000
# Generate random data
data = rand(n, 2)
# Make a random array to mimic datingLabels
labels = randint(1, 4, n)
# Separate the data according to the labels
data_1 = data[labels==1]
data_2 = data[labels==2]
data_3 = data[labels==3]
# Plot each set of points separately
# 's' is the size parameter.
# 'c' is the color parameter.
# I have chosen the colors so that they match the plot shown.
# With each set of points, input the desired label for the legend.
plt.scatter(data_1[:,0], data_1[:,1], s=15, c='r', label="label 1")
plt.scatter(data_2[:,0], data_2[:,1], s=30, c='g', label="label 2")
plt.scatter(data_3[:,0], data_3[:,1], s=45, c='b', label="label 3")
# Put labels on the axes
plt.ylabel("ylabel")
plt.xlabel("xlabel")
# Place the Legend in the plot.
plt.gca().legend(loc="upper left")
# Display it.
plt.show()

The gray borders should become white if you save the figure to a file with plt.savefig instead of displaying it. After saving, remember to run plt.clf() (clear the figure) or plt.cla() (clear the current axes) so you don't end up replotting the same data on top of itself over and over again.
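A minimal sketch of the save-then-clear pattern, assuming several plots are produced in a loop; the random data, the temporary directory and the plot_N.png file names are placeholders.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend: figures are only written to files
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.RandomState(2)
out_dir = tempfile.mkdtemp()  # placeholder output location

for i in range(3):
    data = rng.rand(50, 2)
    plt.scatter(data[:, 0], data[:, 1], s=15, c="r", label="batch %d" % i)
    plt.gca().legend(loc="upper left")
    plt.savefig(os.path.join(out_dir, "plot_%d.png" % i))
    # Clear the figure so the next iteration does not draw on top of this one
    plt.clf()
```

Without the plt.clf() call, each saved file would also contain the points (and legend entries) from every earlier iteration.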

Upvotes: 2

Saullo G. P. Castro

Reputation: 58895

To create the legend you have to:

  • give a label to each plotted data series (the label keyword argument)

  • call the legend() method of the current AxesSubplot object, which you can obtain with plt.gca(), for example.

See the example below:

plt.scatter(datingDataMat[:, 0], datingDataMat[:, 1],
            s=15.0*np.array(datingLabels), c=15.0*np.array(datingLabels),
            label='Label for this data')
plt.gca().legend(loc='upper left')
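Because the snippet above passes every point in one call, the legend gets a single entry. To get one entry per class, one option is a labeled scatter call per class, selected with a boolean mask in a loop. The class names, colors and random stand-in data below are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-ins for the question's variables (assumption)
rng = np.random.RandomState(3)
datingDataMat = rng.rand(1000, 3)
datingLabels = rng.randint(1, 4, size=1000).tolist()

labels = np.array(datingLabels)
# Hypothetical legend text and color for each class label
classes = {1: ("did not like", "red"),
           2: ("small doses", "green"),
           3: ("large doses", "blue")}

ax = plt.gca()
for k, (name, color) in sorted(classes.items()):
    mask = labels == k  # rows belonging to class k
    ax.scatter(datingDataMat[mask, 0], datingDataMat[mask, 1],
               s=15.0 * k, c=color, label=name)
ax.legend(loc="upper left")
```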

Upvotes: 2
