Reputation: 19
I'm really confused about matplotlib in general. I normally just use import matplotlib.pyplot as plt.
And then I do everything with plt.figure(), plt.scatter(), plt.xlabel(), plt.show(), etc. But when I google how to do something like mapping legend entries to colours, I get all these examples that use ax. There is plt.legend(), but the example in the matplotlib documentation just shows plt.legend(handles) and doesn't explain what handles is supposed to be. And if I want to use the ax approach, I have to rewrite all my code, because I wanted to use plt since it's simpler.
Here's my code:
import matplotlib.pyplot as plt
colmap = {
"domestic": "blue",
"cheetah": "red",
"leopard": "green",
"tiger": "black"
}
colours = []
for i in y_train:
    colours.append(colmap[i])
plt.figure(figsize= [15,5])
plt.scatter(X_train[:,0], X_train[:,2],c=colours)
plt.xlabel('weight')
plt.ylabel('height')
plt.grid()
plt.show()
Now I want to add a legend that shows the same colours as in my dictionary. But if I do:
plt.legend(["domestic","cheetah","leopard","tiger"])
it only shows "domestic" in the legend, and its colour is red, which doesn't match how I've colour-coded the points. Is there a way to do this without rewriting everything with the "ax" approach? And if not, how do I adapt this to ax? Do I just write ax = plt.scatter(....)?
Upvotes: 0
Views: 570
Reputation: 807
RoseGod gave a good example of how to handle your current problem in his answer. As for the general difference between using plt and ax to plot things:
plt calls the pyplot library to do things. Its functions act on the current figure and axes, which is usually the last one you opened. This works perfectly fine when you do simple plotting with only one figure/plot at a time.
ax is an Axes object that refers to one specific (sub)plot and all the elements within that (sub)plot. This gives you full control over everything relating to that (sub)plot, especially when plotting several things at the same time in subplots.
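For comparison, here is a minimal sketch of the same kind of plot written against an Axes object. The data (weights, heights, labels) is made up purely for illustration; in the question it would come from X_train[:, 0], X_train[:, 2] and y_train. Plotting one class per scatter call and passing label= lets ax.legend() build matching entries by itself.

```python
import matplotlib.pyplot as plt

# Made-up sample data standing in for X_train[:, 0], X_train[:, 2] and y_train
weights = [4, 5, 60, 65, 30, 35, 200, 220]
heights = [25, 28, 80, 85, 60, 65, 100, 110]
labels = ["domestic", "domestic", "cheetah", "cheetah",
          "leopard", "leopard", "tiger", "tiger"]

colmap = {"domestic": "blue", "cheetah": "red",
          "leopard": "green", "tiger": "black"}

# plt.subplots() creates a Figure and returns it together with its Axes
fig, ax = plt.subplots(figsize=(15, 5))

# One scatter call per class, so each class becomes its own artist with a label
for name, colour in colmap.items():
    xs = [w for w, lab in zip(weights, labels) if lab == name]
    ys = [h for h, lab in zip(heights, labels) if lab == name]
    ax.scatter(xs, ys, c=colour, label=name)

# Setters such as plt.xlabel() become ax.set_xlabel() on the Axes
ax.set_xlabel("weight")
ax.set_ylabel("height")
ax.grid()
ax.legend()  # collects the label= given to each scatter call
```

Note that plt.scatter(...) returns a PathCollection (the plotted points), not an Axes, so ax = plt.scatter(...) would not give you an Axes. If you want to keep existing plt code, ax = plt.gca() returns the current Axes, and plt and ax calls can be mixed on the same plot. Call plt.show() at the end to display the figure.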
See also: matplotlib Axes.plot() vs pyplot.plot()
Upvotes: 0
Reputation: 1234
No data was provided, but this code can help you understand how to add color to a scatter plot in matplotlib:
import matplotlib.pyplot as plt
# data for scatter plots
x = list(range(0,30))
y = [i**2 for i in x]
# data for mapping class to color
y_train = ['domestic','cheetah', 'cheetah', 'tiger', 'domestic',
'leopard', 'tiger', 'domestic', 'cheetah', 'domestic',
'leopard', 'leopard', 'domestic', 'domestic', 'domestic',
'domestic', 'cheetah', 'tiger', 'cheetah', 'cheetah',
'domestic', 'domestic', 'domestic', 'cheetah', 'leopard',
'cheetah', 'domestic', 'cheetah', 'tiger', 'domestic']
# color mapper
colmap = {
"domestic": "blue",
"cheetah": "red",
"leopard": "green",
"tiger": "black"
}
# create color array
colors = [colmap[i] for i in y_train]
# plot scatter
plt.figure(figsize=(15,5))
plt.scatter(x, y, c=colors)
plt.xlabel('weight')
plt.ylabel('height')
plt.grid()
plt.show()
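The code above colours the points but stops before the legend. With a single plt.scatter call there is only one artist on the plot, which is why plt.legend(["domestic", "cheetah", "leopard", "tiger"]) showed only one entry. A minimal sketch that stays with plt is to pass plt.legend(handles=...) a list of proxy artists built from the same colmap; here the handles are matplotlib.lines.Line2D objects that are never drawn on the axes and only supply a marker, colour and label for each entry. The sample data is a shortened, made-up version of the y_train list above.

```python
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

colmap = {
    "domestic": "blue",
    "cheetah": "red",
    "leopard": "green",
    "tiger": "black",
}

# Made-up sample data in the same form as the answer's y_train
y_train = ["domestic", "cheetah", "cheetah", "tiger", "domestic",
           "leopard", "tiger", "domestic", "cheetah", "leopard"]
x = list(range(len(y_train)))
y = [i**2 for i in x]

colors = [colmap[label] for label in y_train]

plt.figure(figsize=(15, 5))
plt.scatter(x, y, c=colors)
plt.xlabel("weight")
plt.ylabel("height")
plt.grid()

# One proxy handle per class: a round marker, no connecting line
handles = [
    Line2D([], [], marker="o", linestyle="", color=colour, label=name)
    for name, colour in colmap.items()
]
legend = plt.legend(handles=handles)
```

Because the legend entries are generated from colmap, they cannot drift out of sync with the point colours. matplotlib.patches.Patch(color=colour, label=name) works the same way if coloured squares are preferred over round markers. As before, finish with plt.show() to display the plot.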
Upvotes: 1