This is from Chapter 2 of the book Machine Learning in Action, and I am trying to reproduce the scatter plot shown in that chapter.
The author has posted the code for the plot here; I suspect it is a bit hacky (he also mentions that the code is sloppy because it is outside the book's scope).
Here is my attempt to re-create the plot:
First, the .txt file holding the data looks like this (source: "datingTestSet2.txt" in Ch. 2 here); the columns are tab-separated:
40920 8.326976 0.953952 largeDoses
14488 7.153469 1.673904 smallDoses
26052 1.441871 0.805124 didntLike
75136 13.147394 0.428964 didntLike
38344 1.669788 0.134296 didntLike
...
Assume datingDataMat is a numpy.ndarray of shape (1000, 3), where column 0 is "Frequent Flier Miles Per Year", column 1 is "% Time Playing Video Games", and column 2 is "Liters of ice cream consumed per week", as shown in the sample above.
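Assume datingLabels is a list of ints 1, 2, or 3, meaning "Did Not Like", "Liked in Small Doses", and "Liked in Large Doses" respectively, taken from the last column above. (The sample shows these as didntLike, smallDoses and largeDoses; file2matrix below parses the last column with int(), so it expects the integer codes.)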
Here is the code I have to create the plot (full details for file2matrix are at the end):
```python
import numpy as np
import matplotlib.pyplot as plt

datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")

fig = plt.figure()
ax = fig.add_subplot(111)
plt.xlabel("Freq flier miles")
plt.ylabel("% time video games")
# Not sure how to finish this: plt.legend([1, 2, 3], ["did not like", "small doses", "large doses"])
# Marker size (s) and color value (c) both scale with the class label
plt.scatter(datingDataMat[:, 0], datingDataMat[:, 1],
            s=15.0 * np.array(datingLabels), c=15.0 * np.array(datingLabels))
plt.show()
```
The output is here:
My main concern is how to create this legend. Is there a way to do this without needing a direct handle to the points?
Second, I would like to change the colors to match those in the book's plot. Is there a way to do this without having some kind of "handle" on the individual points?
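For the color question, one possible approach (a minimal sketch, assuming you keep a single scatter call) is to pass the raw label values as c together with a three-entry ListedColormap and fixed vmin/vmax, so labels 1, 2 and 3 each map to one fixed color. The colors "red", "green" and "blue" are placeholders to swap for whatever matches the book, and the random x/y/labels only stand in for datingDataMat and datingLabels.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

rng = np.random.default_rng(0)
n = 100
x = rng.random(n)               # stand-in for datingDataMat[:, 0]
y = rng.random(n)               # stand-in for datingDataMat[:, 1]
labels = rng.integers(1, 4, n)  # stand-in for np.array(datingLabels), values 1..3

# One color per class. With vmin=1 and vmax=3 the labels normalize to
# 0.0, 0.5 and 1.0, which pick the first, second and third list entry.
label_cmap = ListedColormap(["red", "green", "blue"])  # placeholder colors

sc = plt.scatter(x, y, s=15.0 * labels, c=labels,
                 cmap=label_cmap, vmin=1, vmax=3)
plt.xlabel("Freq flier miles")
plt.ylabel("% time video games")
```

Fixing vmin and vmax keeps the label-to-color mapping stable even if a subset of the data happens not to contain every class.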
For reference, here is the file2matrix implementation:
```python
import numpy as np

def file2matrix(filename):
    # First pass: count the lines so the output array can be preallocated
    with open(filename) as fr:
        numberOfLines = len(fr.readlines())
    returnMat = np.zeros((numberOfLines, 3))  # one row per sample, three features
    classLabelVector = []
    # Second pass: parse each tab-separated line
    with open(filename) as fr:
        for index, line in enumerate(fr):
            listFromLine = line.strip().split('\t')
            returnMat[index, :] = listFromLine[0:3]  # FF miles/yr, % time gaming, L ice cream/wk
            classLabelVector.append(int(listFromLine[-1]))  # class label 1, 2 or 3
    return returnMat, classLabelVector
```
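The same parsing can be done with np.loadtxt in a single pass. This is a sketch using a hypothetical helper, load_dating_data, assuming every row is tab-separated with three numeric features followed by an integer label. The sample rows are the three from the question with their text labels replaced by the integer codes (largeDoses = 3, smallDoses = 2, didntLike = 1).

```python
import io
import numpy as np

def load_dating_data(source):
    """Hypothetical alternative to file2matrix built on np.loadtxt.

    Assumes tab-separated rows: three numeric features, then an integer label.
    `source` may be a filename or any file-like object.
    """
    raw = np.loadtxt(source, delimiter="\t")
    features = raw[:, :3]                    # FF miles/yr, % time gaming, L ice cream/wk
    labels = raw[:, 3].astype(int).tolist()  # class labels 1, 2 or 3 as a Python list
    return features, labels

# Three rows from the question, with the text labels replaced by their integer codes
sample = io.StringIO(
    "40920\t8.326976\t0.953952\t3\n"
    "14488\t7.153469\t1.673904\t2\n"
    "26052\t1.441871\t0.805124\t1\n"
)
datingDataMat, datingLabels = load_dating_data(sample)
```

Calling load_dating_data("datingTestSet2.txt") should then return the same kind of (array, list) pair as file2matrix.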
Here's an example that mimics your code and uses the approach described in Saullo Castro's answer. It also shows how to set the colors. For more information on the available colors, see the documentation at http://matplotlib.org/api/colors_api.html
It would also be worth looking at the scatter plot documentation at http://matplotlib.org/1.3.1/api/pyplot_api.html#matplotlib.pyplot.scatter
```python
from numpy.random import rand, randint
from matplotlib import pyplot as plt

n = 1000
# Generate random data
data = rand(n, 2)
# Make a random array to mimic datingLabels
labels = randint(1, 4, n)

# Separate the data according to the labels
data_1 = data[labels==1]
data_2 = data[labels==2]
data_3 = data[labels==3]

# Plot each set of points separately
# 's' is the size parameter.
# 'c' is the color parameter.
# I have chosen the colors so that they match the plot shown.
# With each set of points, input the desired label for the legend.
plt.scatter(data_1[:,0], data_1[:,1], s=15, c='r', label="label 1")
plt.scatter(data_2[:,0], data_2[:,1], s=30, c='g', label="label 2")
plt.scatter(data_3[:,0], data_3[:,1], s=45, c='b', label="label 3")

# Put labels on the axes
plt.ylabel("ylabel")
plt.xlabel("xlabel")
# Place the Legend in the plot.
plt.gca().legend(loc="upper left")
# Display it.
plt.show()
```
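The three separate scatter calls above can also be written as a loop over the label values, which keeps each class's legend text, color and size in one place. A minimal sketch, assuming the question's class names and the colors and sizes from the example above; the random data is a stand-in for datingDataMat and datingLabels, and the styles dictionary is an illustrative name.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
n = 1000
data = rng.random((n, 2))       # stand-in for datingDataMat[:, :2]
labels = rng.integers(1, 4, n)  # stand-in for np.array(datingLabels)

# label value -> (legend text, color, marker size)
styles = {
    1: ("did not like", "r", 15),
    2: ("small doses", "g", 30),
    3: ("large doses", "b", 45),
}

fig, ax = plt.subplots()
for value, (name, color, size) in styles.items():
    subset = data[labels == value]  # boolean mask keeps one class
    ax.scatter(subset[:, 0], subset[:, 1], s=size, c=color, label=name)

ax.set_xlabel("Freq flier miles")
ax.set_ylabel("% time video games")
ax.legend(loc="upper left")
```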
The gray borders should become white if you use plt.savefig to save the figure to a file instead of displaying it.
Remember to run plt.clf() (clear the figure) or plt.cla() (clear the current axes) after saving to a file, so you don't end up replotting the same data on top of itself over and over again.
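A minimal sketch of that save-then-clear pattern, assuming several figures are produced in a loop. It saves into in-memory io.BytesIO buffers only so the example writes no files; passing a filename string to plt.savefig works the same way. The plotted points are hypothetical placeholders.

```python
import io
import matplotlib
matplotlib.use("Agg")  # no window needed when only saving
import matplotlib.pyplot as plt

buffers = []
for run in range(3):
    # Hypothetical data for each figure; replace with your own scatter calls
    plt.scatter([0, 1, 2], [run, run + 1, run + 2], label=f"run {run}")
    plt.gca().legend(loc="upper left")
    buf = io.BytesIO()
    plt.savefig(buf, format="png")  # save instead of plt.show()
    buffers.append(buf)
    plt.clf()  # clear the whole figure so the next run starts empty
```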
To create the legend you have to:
- give a label to each scatter call with the label= keyword, and
- call the legend() method on the current AxesSubplot object, which can be obtained with plt.gca(), for example.
See the example below:
```python
# A single scatter call produces a single legend entry
plt.scatter(datingDataMat[:, 0], datingDataMat[:, 1],
            s=15.0*np.array(datingLabels), c=15.0*np.array(datingLabels),
            label='Label for this data')
plt.gca().legend(loc='upper left')
```
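Since one scatter call yields one legend entry, one way to get an entry per class while still using a single call (and never touching the plotted points) is to build the legend from proxy artists. This is a sketch of a different matplotlib technique than the answer above, assuming random stand-in data and placeholder colors; the Line2D objects are empty and only describe how each legend entry should look.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

rng = np.random.default_rng(0)
n = 1000
data = rng.random((n, 2))       # stand-in for datingDataMat[:, :2]
labels = rng.integers(1, 4, n)  # stand-in for np.array(datingLabels)

names = ["did not like", "small doses", "large doses"]
colors = ["red", "green", "blue"]  # placeholder colors

# One scatter call for all points; each point's color is looked up from its label (1..3)
point_colors = [colors[v - 1] for v in labels]
plt.scatter(data[:, 0], data[:, 1], s=15.0 * labels, c=point_colors)

# Proxy artists: empty Line2D objects used only to draw the legend entries
handles = [
    Line2D([], [], linestyle="", marker="o", color=col, label=name)
    for name, col in zip(names, colors)
]
plt.gca().legend(handles=handles, loc="upper left")
```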