Reputation: 375
I have a list of data X, which is an Nx2 matrix. I want to plot the first element of each row of X against the second element of that row.
First I have separated the first and second elements of X into their own lists: X_comp1, X_comp2.
I also have an Nx1 list of categories (cat), which shows which category each element of X belongs to, i.e. if cat[i] = 3, it means that X[i] belongs to category 3.
I would like to use a different coloured point in the scatter plot for each category.
So far I have only been able to achieve this through hard coding, but this will become very inefficient when there are more categories. Here's my code (it assumes there will be 5 categories):
import matplotlib.pyplot as plt

#sample of X data
X = [[-0.13085379, -0.05958517], [0.02593188, -0.17576942], [-0.12505032, -0.02709171], [0.09790905, -0.18046944], [0.06437596, -0.20600157], [0.16287853, -0.2955353], [-0.52093842, 0.33463338], [-0.03240038, -0.05431373], [-0.09645192, -0.14241157], [0.0807245, -0.26893815]]
X_comp1 = []  #hold all the first components of X
X_comp2 = []  #hold all the second components of X
cat = [1,3,2,1,5,3,2,4,4,1]
#for testing just use 10 values, full file has over 3000 entries and 50 categories
for i in range(10):
    X_comp1.append(X[i][0])
    X_comp2.append(X[i][1])
for x1,x2,c in zip(X_comp1,X_comp2,cat):
    if c == 1:
        plt.scatter(x1,x2,c = 'b')
    elif c == 2:
        plt.scatter(x1,x2,c = 'g')
    elif c == 3:
        plt.scatter(x1,x2,c = 'r')
    elif c == 4:
        plt.scatter(x1,x2,c = 'c')
    elif c == 5:
        plt.scatter(x1,x2,c = 'm')
plt.legend([1,2,3,4,5])
plt.show()
I would like to make it flexible with respect to the number of categories, so that I don't end up writing an if statement for every category.
In order to achieve this I thought of having a list of colours:
colours = ["b", "g", "r", "c", "m",...]#number of colours depends on no. of categories
#This is the only element which I would like to remain hard coded, as I want to choose the colours
Here each colour corresponds to a category. The program then iterates through all the data and plots each point according to its category. But I'm not sure how this could be implemented.
Upvotes: 1
Reputation: 2052
For a pretty plot, you could also work with seaborn:
import seaborn as sns
import pandas as pd
sns.set()
df = pd.DataFrame({'X_1': X_comp1,'X_2':X_comp2, 'category':cat})
sns.scatterplot(data=df,x='X_1', y='X_2', hue='category')
If you care about which category should have what color, you can pass the palette parameter with your own category-color dict:
my_palette = {1: 'b', 2: 'g', 3: 'r', 4: 'c', 5: 'm'}
sns.scatterplot(data=df,x='X_1', y='X_2', hue='category', palette=my_palette)
There are also a bunch of predefined palettes if you're not satisfied with seaborn's default choice.
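With 50 categories, typing the dict by hand gets tedious, so here is a rough sketch of building it from one of the predefined palettes instead. It assumes the sample data from the question; the palette name "husl" is just my pick for illustration (any name accepted by sns.color_palette should do), and you can still overwrite individual entries afterwards if you care about particular colors.

```python
import pandas as pd
import seaborn as sns

# Sample data from the question, already split into components.
X_comp1 = [-0.13085379, 0.02593188, -0.12505032, 0.09790905, 0.06437596,
           0.16287853, -0.52093842, -0.03240038, -0.09645192, 0.0807245]
X_comp2 = [-0.05958517, -0.17576942, -0.02709171, -0.18046944, -0.20600157,
           -0.2955353, 0.33463338, -0.05431373, -0.14241157, -0.26893815]
cat = [1, 3, 2, 1, 5, 3, 2, 4, 4, 1]

df = pd.DataFrame({'X_1': X_comp1, 'X_2': X_comp2, 'category': cat})

# One color per distinct category, however many there are.
categories = sorted(int(c) for c in df['category'].unique())
colors = sns.color_palette('husl', n_colors=len(categories))  # list of RGB tuples
my_palette = dict(zip(categories, colors))

# Optional: pin specific categories to hand-picked colors.
my_palette[1] = 'b'

ax = sns.scatterplot(data=df, x='X_1', y='X_2', hue='category', palette=my_palette)
```

Because the palette is a dict keyed by category, the same category keeps the same color even if you later plot only a subset of the data.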
Upvotes: 1
Reputation: 9051
Try this:
color_dict = {1: 'b', 2: 'g', 3: 'r', 4: 'c', 5: 'm'}
for x1, x2, c in zip(X_comp1, X_comp2, cat):
    if c in color_dict:
        plt.scatter(x1, x2, c = color_dict[c])
plt.legend(list(color_dict.keys()))
plt.show()
Instead of checking each value of c separately, we look the colour up in a dictionary, which removes all the if statements.
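One caveat with the loop above: plt.scatter is called once per point, so the legend labels end up attached to the first few points drawn rather than to the categories, and 3000 separate calls will be slow. A possible variant is sketched below. It groups the points by category first and makes one scatter call per category with a label. The helper names group_by_category and plot_by_category are made up for this example, and the data is the sample from the question.

```python
from collections import defaultdict

import matplotlib.pyplot as plt

# Sample data from the question (commas added so it is a valid list of lists).
X = [[-0.13085379, -0.05958517], [0.02593188, -0.17576942],
     [-0.12505032, -0.02709171], [0.09790905, -0.18046944],
     [0.06437596, -0.20600157], [0.16287853, -0.2955353],
     [-0.52093842, 0.33463338], [-0.03240038, -0.05431373],
     [-0.09645192, -0.14241157], [0.0807245, -0.26893815]]
cat = [1, 3, 2, 1, 5, 3, 2, 4, 4, 1]

# Hand-picked colours, one per category (the only hard-coded part).
color_dict = {1: 'b', 2: 'g', 3: 'r', 4: 'c', 5: 'm'}


def group_by_category(points, categories):
    """Return {category: ([first components], [second components])}."""
    groups = defaultdict(lambda: ([], []))
    for (x1, x2), c in zip(points, categories):
        groups[c][0].append(x1)
        groups[c][1].append(x2)
    return dict(groups)


def plot_by_category(points, categories, colors):
    """Draw one scatter series per category so each gets its own legend entry."""
    fig, ax = plt.subplots()
    for c, (xs, ys) in sorted(group_by_category(points, categories).items()):
        if c not in colors:
            continue  # skip categories without a chosen colour, as in the loop above
        ax.scatter(xs, ys, c=colors[c], label=str(c))
    ax.legend(title='category')
    return ax


if __name__ == '__main__':
    plot_by_category(X, cat, color_dict)
    plt.show()
```

Adding a category then only means adding one entry to color_dict; the plotting code stays the same.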
Upvotes: 1