The Anonymous

Reputation: 19

Matplotlib: How can I add labels to the legend?

I am trying to split the data by whether each passenger is male, plotting Age on the x-axis and Fare on the y-axis. I want the legend to show two labels, male and female, each with its respective color. Can anyone help me do this?

Code:

import matplotlib.pyplot as plt
import pandas as pd
df = pd.read_csv('https://sololearn.com/uploads/files/titanic.csv')
df['male']=df['Sex']=='male'
sc1= plt.scatter(df['Age'],df['Fare'],c=df['male'])
plt.legend()
plt.show()

Upvotes: 0

Views: 1437

Answers (3)

Dhiraj Bansal

Reputation: 455

This can be achieved by splitting the data into two separate DataFrames and then setting a label for each one.

import matplotlib.pyplot as plt
import pandas as pd

df = pd.read_csv('https://sololearn.com/uploads/files/titanic.csv')

# Split the rows into one DataFrame per group
subset1 = df[df['Sex'] == 'male']
subset2 = df[df['Sex'] != 'male']

# One scatter call per subset, each with its own legend label
plt.scatter(subset1['Age'], subset1['Fare'], label='Male')
plt.scatter(subset2['Age'], subset2['Fare'], label='Female')
plt.legend()
plt.show()
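The same idea generalizes to any number of groups by looping over DataFrame.groupby instead of writing one subset per category. This is a minimal sketch, not the answerer's code: scatter_by_group is a hypothetical helper, and the small toy DataFrame below stands in for the Titanic CSV so the example runs without a download.

```python
import matplotlib.pyplot as plt
import pandas as pd


def scatter_by_group(df, x, y, group_col, ax=None):
    """Hypothetical helper: one scatter call per group, each with its own label."""
    if ax is None:
        ax = plt.gca()
    # groupby sorts the keys by default, so legend entries appear in sorted order
    for name, group in df.groupby(group_col):
        ax.scatter(group[x], group[y], label=str(name).capitalize())
    ax.legend()
    return ax


# Hypothetical toy data standing in for the Titanic CSV
toy = pd.DataFrame({
    "Age": [22, 38, 26, 35, 54],
    "Fare": [7.25, 71.28, 7.93, 53.10, 51.86],
    "Sex": ["male", "female", "female", "female", "male"],
})

ax = scatter_by_group(toy, "Age", "Fare", "Sex")
labels = [t.get_text() for t in ax.get_legend().get_texts()]
print(labels)  # ['Female', 'Male']
```

Call plt.show() afterwards to display the figure. With the real data you would pass the DataFrame loaded from the CSV instead of toy.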


Upvotes: 0

Ynjxsjmh

Reputation: 30032

The PathCollection.legend_elements method can be used to control how many legend entries are created and how they are labeled.

import matplotlib.pyplot as plt
import pandas as pd

df = pd.read_csv('https://sololearn.com/uploads/files/titanic.csv')
df['male'] = df['Sex'] == 'male'

sc1 = plt.scatter(df['Age'], df['Fare'], c=df['male'])

# legend_elements() returns one handle per unique color value, sorted
# ascending: False (female) first, then True (male).
plt.legend(handles=sc1.legend_elements()[0], labels=['female', 'male'])

plt.show()
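Because legend_elements creates its handles in the sorted order of the unique values in the color array, the label list has to follow that same order. One way to avoid mismatching them, sketched below under the assumption of a small hypothetical toy DataFrame in place of the Titanic CSV, is to derive the labels from those unique values instead of typing them by hand.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical toy data standing in for the Titanic CSV
toy = pd.DataFrame({
    "Age": [22, 38, 26, 35, 54],
    "Fare": [7.25, 71.28, 7.93, 53.10, 51.86],
    "Sex": ["male", "female", "female", "female", "male"],
})
is_male = toy["Sex"] == "male"

fig, ax = plt.subplots()
# The boolean Series is mapped to colors as 0.0 (False) and 1.0 (True)
sc = ax.scatter(toy["Age"], toy["Fare"], c=is_male)

# One handle per unique color value, in ascending order
handles, _ = sc.legend_elements()

# Build labels from the same sorted unique values so they line up with handles
values = np.unique(np.asarray(sc.get_array()))
labels = ["male" if v else "female" for v in values]
ax.legend(handles=handles, labels=labels)
print(labels)  # ['female', 'male']
```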


See the Matplotlib Legend guide and the "Scatter plots with a legend" example for reference.

Upvotes: 1

Nikhil Kumar

Reputation: 1232

You could use the seaborn library, which builds on top of matplotlib, to do exactly this. You can scatter-plot 'Age' against 'Fare' and colour-code the points by 'Sex' just by passing the hue parameter to sns.scatterplot, as follows:

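import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.read_csv('https://sololearn.com/uploads/files/titanic.csv')

plt.figure()

# No need to call plt.legend(); seaborn generates the labels and legend
# automatically from the hue column. Columns are passed by name with data=df,
# because recent seaborn versions (0.12+) no longer accept x and y positionally.
sns.scatterplot(data=df, x='Age', y='Fare', hue='Sex')

plt.show()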
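If you also want to control the order of the legend entries, sns.scatterplot accepts a hue_order list. The sketch below is an illustration rather than part of the original answer; it uses a small hypothetical toy DataFrame in place of the Titanic CSV and only checks that both hue levels appear in the legend in the requested order.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data standing in for the Titanic CSV
toy = pd.DataFrame({
    "Age": [22, 38, 26, 35, 54],
    "Fare": [7.25, 71.28, 7.93, 53.10, 51.86],
    "Sex": ["male", "female", "female", "female", "male"],
})

fig, ax = plt.subplots()
# hue_order fixes which level is listed (and coloured) first
ax = sns.scatterplot(data=toy, x="Age", y="Fare", hue="Sex",
                     hue_order=["female", "male"], ax=ax)

texts = [t.get_text() for t in ax.get_legend().get_texts()]
print(texts)
```

Call plt.show() to display the plot.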

Seaborn generates nicer plots with less code and more functionality.

You can install seaborn from PyPI using pip install seaborn.

Refer to the Seaborn docs.

Upvotes: 1
