Reputation: 93
I need to create a graph in Pylab in which I plot colored points along the y axis. The y axis goes from 0 to 100. I also have a list of 100 elements, each of which is either +1 or -1. This list has to correspond to the y axis of the graph.
For instance, if the fifth element in the list is +1, I need to plot a green dot at y=5 on the y axis. If the fifth element in the list is -1, the point has to be red.
I have to do this for all elements in the list.
I have graphed simple graphs in Pylab, but I am totally lost in this case. Any help will be appreciated. Thanks!!
Upvotes: 0
Views: 183
Reputation: 879361
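import matplotlib.pyplot as plt
import numpy as np

data = np.array([1,1,-1,-1,1])
cmap = np.array([(1,0,0), (0,1,0)])  # row 0 = red, row 1 = green
# idx is 0 where data is -1 and 1 where data is 1
uniqdata, idx = np.unique(data, return_inverse=True)
N = len(data)
fig, ax = plt.subplots()
# all dots at x = 0, at heights 1..N, colored by the matching row of cmap
plt.scatter(np.zeros(N), np.arange(1, N+1), s=100, c=cmap[idx])
plt.grid()
plt.show()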
This yields a single column of dots at x=0: green where data is 1, red where data is -1.
Explanation:
If you print out np.unique(data, return_inverse=True), you'll see it returns a tuple of arrays:
In [71]: np.unique(data, return_inverse=True)
Out[71]: (array([-1, 1]), array([1, 1, 0, 0, 1]))
The first array says the unique values in data are -1 and 1. The second array assigns the value 0 wherever data is -1 and the value 1 wherever data is 1. Essentially, np.unique allows us to transform [1,1,-1,-1,1] to [1, 1, 0, 0, 1]. Now cmap[idx] is an array of RGB values:
In [74]: cmap[idx]
Out[74]:
array([[0, 1, 0],
       [0, 1, 0],
       [1, 0, 0],
       [1, 0, 0],
       [0, 1, 0]])
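One caveat about the np.unique step above, sketched here under the assumption that +1 should always be green and -1 always red: the inverse indices are positions within the sorted unique values, so a data set that happens to contain only 1s gets index 0 everywhere and would be drawn entirely in red. A comparison-based lookup avoids this. The helper name color_index is hypothetical and not part of the original answer.

```python
import numpy as np

cmap = np.array([(1, 0, 0), (0, 1, 0)])  # row 0 = red, row 1 = green

def color_index(data):
    """Map -1 -> 0 (red) and +1 -> 1 (green), whichever values actually occur."""
    return (np.asarray(data) > 0).astype(int)

all_ones = np.array([1, 1, 1])

# np.unique sees only one distinct value, so every inverse index is 0 (red).
_, idx_unique = np.unique(all_ones, return_inverse=True)
print(cmap[idx_unique])

# The comparison-based index still selects the green row.
print(cmap[color_index(all_ones)])
```

For data that contains both -1 and 1, color_index(data) and the idx returned by np.unique are the same array, so it can be dropped in wherever cmap[idx] is used.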
This is an application of so-called "fancy indexing" on NumPy arrays. cmap[0] is the first row of cmap. cmap[1] is the second row of cmap. cmap[idx] is an array such that the ith element in cmap[idx] is cmap[idx[i]]. So, you end up with cmap[idx] being a 2D array where the ith row is cmap[idx[i]]. Thus cmap[idx] can be thought of as a sequence of RGB color values.
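The row-lookup rule described above can be checked directly. This is only a small sanity check, reusing the cmap and idx values from the example; the name colors_loop is made up for the comparison.

```python
import numpy as np

cmap = np.array([(1, 0, 0), (0, 1, 0)])  # row 0 = red, row 1 = green
idx = np.array([1, 1, 0, 0, 1])          # index array from the example above

colors = cmap[idx]  # fancy indexing: one row of cmap per entry of idx

# Same result built row by row with an explicit loop.
colors_loop = np.array([cmap[i] for i in idx])

print(colors.shape)                         # (5, 3): one RGB triple per data point
print(np.array_equal(colors, colors_loop))  # True
```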
If you have more than one set of dots and you wish to plot them in columns, the simplest way I can think of is to call ax.scatter once for each list of data:
import matplotlib.pyplot as plt
import numpy as np

def plot_data(ax, data, xval):
    # Draw one column of dots at x = xval, colored via the module-level cmap.
    N = len(data)
    uniqdata, idx = np.unique(data, return_inverse=True)
    ax.scatter(np.ones(N)*xval, np.arange(1, N+1), s=100, c=cmap[idx])

cmap = np.array([(1,0,0), (0,1,0)])  # row 0 = red, row 1 = green
fig, ax = plt.subplots()
data = np.array([1,1,-1,-1,1])
data2 = np.array([1,-1,1,1,-1])
plot_data(ax, data, 0)   # first column at x = 0
plot_data(ax, data2, 1)  # second column at x = 1
plt.grid()
plt.show()
The nice thing about this is that it is relatively easy to understand. The bad thing is that it calls ax.scatter more than once. If you have lots of data sets, it is more efficient to collate your data and call ax.scatter once. This is faster for Matplotlib, but it's a little more complicated to code:
import matplotlib.pyplot as plt
import numpy as np

def plot_dots(ax, datasets):
    # Collate every data set into flat x, y and color arrays, then scatter once.
    N = sum(len(data) for data in datasets)
    # x: the column number i of each data set, repeated once per point
    x = np.fromiter(
        (i for i, data in enumerate(datasets) for j in np.arange(len(data))),
        dtype='float', count=N)
    # y: positions 1..len(data) within each data set
    y = np.fromiter(
        (j for data in datasets for j in np.arange(1, len(data)+1)),
        dtype='float', count=N)
    # c: the RGB values of every point, flattened and reshaped to (N, 3)
    c = np.fromiter(
        (val for data in datasets
         for rgb in cmap[np.unique(data, return_inverse=True)[-1]]
         for val in rgb),
        dtype='float', count=3*N).reshape(-1,3)
    ax.scatter(x, y, s=100, c=c)

cmap = np.array([(1,0,0), (0,1,0)])  # row 0 = red, row 1 = green
fig, ax = plt.subplots()
N = 100
# random 0/1 values stand in for the -1/+1 data
datasets = [np.random.randint(2, size=5) for i in range(N)]
plot_dots(ax, datasets)
plt.grid()
plt.show()
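The same collation could probably be written with np.repeat and np.concatenate instead of np.fromiter. The following is a sketch rather than the answer's own method: the helper name collate is hypothetical, and it assumes that positive values should be green and everything else red, which sidesteps np.unique entirely.

```python
import numpy as np

cmap = np.array([(1, 0, 0), (0, 1, 0)])  # row 0 = red, row 1 = green

def collate(datasets):
    """Return x, y and RGB arrays for a single scatter call (hypothetical helper)."""
    lengths = [len(data) for data in datasets]
    # Column number of each data set, repeated once per point in it.
    x = np.repeat(np.arange(len(datasets)), lengths).astype(float)
    # Row positions 1..len(data) within each data set.
    y = np.concatenate([np.arange(1, n + 1) for n in lengths]).astype(float)
    # Assumption: positive values are green, everything else is red.
    values = np.concatenate([np.asarray(data) for data in datasets])
    c = cmap[(values > 0).astype(int)]
    return x, y, c

x, y, c = collate([np.array([1, 1, -1]), np.array([-1, 1])])
print(x)  # [0. 0. 0. 1. 1.]
print(y)  # [1. 2. 3. 1. 2.]
```

The three arrays can then be passed to ax.scatter(x, y, s=100, c=c), just as plot_dots does.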
Upvotes: 2