Reputation: 81
I am new to matplotlib and am trying to plot this linear regression with a customized color for a specific value of the independent variable:
colors=['red','blue','green','black']
X=array([[1000],[2000],[3000],[4500]])
y=array([[200000],[200000],[200000],[200000]])
plt.scatter(X, y, color = colors[0])
plt.plot(X, lin_reg.predict(X), color = 'blue')
plt.xlabel('X')
plt.ylabel('y')
plt.show()
I need to set the color to black when X==3000 so I am using np.where:
colors_z=(np.where(X==3000,colors[4],colors[0]))
plt.scatter(X, y, color = colors_z)
But I am getting a color error. Any idea what I am doing wrong? Thanks
Upvotes: 1
Views: 997
Reputation: 91
You've set colors_z to use colors[4], but the list colors has only 4 entries (indices 0 to 3), so that index is out of range. I'd drop the np.where in favor of a simple if statement or ternary operator. Something like:
# ternary operator example
plt.scatter(X, y, color = [colors[3] if xi == 3000 else colors[0] for xi in X.ravel()])
Note that this only matches points where X is exactly equal to 3000. It runs without error on my console, so it should work with your regression.
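If you would rather keep np.where, it should also work once the index is valid. A minimal sketch, assuming X is the (4, 1) NumPy array from the question; the ravel() call is an addition here so that the result is a flat sequence with one color per point:

```python
import numpy as np

colors = ['red', 'blue', 'green', 'black']
X = np.array([[1000], [2000], [3000], [4500]])

# Valid indices for a 4-item list are 0..3, so 'black' is colors[3]
colors_z = np.where(X.ravel() == 3000, colors[3], colors[0])
print(colors_z.tolist())  # ['red', 'red', 'black', 'red']
```

Passing colors_z.tolist() as the color argument of plt.scatter should then draw only the X == 3000 point in black.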
Upvotes: 1
Reputation: 1015
I think this does what you're looking for; using np.where is a bit overkill for this purpose:
import matplotlib.pyplot as plt
X = [1000, 2000, 3000, 4500]
y = [200000, 3000, 200000, 200000]
# One color per point: red where X is 3000, blue everywhere else
colors = list(map(lambda x: 'r' if x == 3000 else 'b', X))
plt.scatter(X, y, color=colors)
plt.xlabel('X')
plt.ylabel('y')
plt.show()
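For what it's worth, the map/lambda line can also be written as a list comprehension, which some people find easier to read. A small sketch; point_colors is a hypothetical helper name used only for illustration, not part of matplotlib:

```python
def point_colors(xs, target, hit='r', miss='b'):
    """Return one color per value in xs: `hit` where it equals target, else `miss`."""
    return [hit if x == target else miss for x in xs]

X = [1000, 2000, 3000, 4500]
colors = point_colors(X, 3000)
print(colors)  # ['b', 'b', 'r', 'b']
```

The resulting list can be passed to plt.scatter through the color argument in the same way as the map version.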
Upvotes: 1