Evan Gertis

Reputation: 2052

How to graph multiple lines using sns.scatterplot

I have written a program like so:

# Author: Evan Gertis
# Date  : 11/09
# program: Linear Regression
# Resource: https://seaborn.pydata.org/generated/seaborn.scatterplot.html       
import seaborn as sns
import pandas as pd
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Step 1: load the data
grades = pd.read_csv("grades.csv") 
logging.info(grades.head())

# Step 2: plot the data
plot = sns.scatterplot(data=grades, x="Hours", y="GPA")
fig = plot.get_figure()
fig.savefig("out.png")

Using the data set in grades.csv:

Hours,GPA,Hours,GPA,Hours,GPA
11,2.84,9,2.85,25,1.85
5,3.20,5,3.35,6,3.14
22,2.18,14,2.60,9,2.96
23,2.12,18,2.35,20,2.30
20,2.55,6,3.14,14,2.66
20,2.24,9,3.05,19,2.36
10,2.90,24,2.06,21,2.24
19,2.36,25,2.00,7,3.08
15,2.60,12,2.78,11,2.84
18,2.42,6,2.90,20,2.45

I would like to plot all of the relationships, but at the moment I only get one plot.

Expected: all relationships plotted

Actual:

[image: a single scatter plot of Hours vs. GPA]

I wrote a basic program and I was expecting all of the relationships to be plotted.

Upvotes: 1

Views: 143

Answers (2)

Trenton McKinney

Reputation: 62413

  • There are better options than manually creating a plot for each group of columns
  • Because the column names in the file are duplicated, pandas automatically renames the repeats (Hours.1, GPA.1, Hours.2, GPA.2).
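The renaming is easy to check on a two-row excerpt of the question's data. Here io.StringIO stands in for grades.csv only so the snippet runs without the file:

```python
import io

import pandas as pd

# two-row excerpt of the question's grades.csv, with duplicated header names
csv_text = """Hours,GPA,Hours,GPA,Hours,GPA
11,2.84,9,2.85,25,1.85
5,3.20,5,3.35,6,3.14
"""

# read_csv keeps the first occurrence of a name and appends .1, .2, ... to repeats
df = pd.read_csv(io.StringIO(csv_text))
print(list(df.columns))
# ['Hours', 'GPA', 'Hours.1', 'GPA.1', 'Hours.2', 'GPA.2']
```

This is why x="Hours", y="GPA" in the question only ever selects the first pair of columns.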

Imports and DataFrame

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# read the data from the file
df = pd.read_csv('d:/data/gpa.csv')

# display(df)
   Hours   GPA  Hours.1  GPA.1  Hours.2  GPA.2
0     11  2.84        9   2.85       25   1.85
1      5  3.20        5   3.35        6   3.14
2     22  2.18       14   2.60        9   2.96
3     23  2.12       18   2.35       20   2.30
4     20  2.55        6   3.14       14   2.66
5     20  2.24        9   3.05       19   2.36
6     10  2.90       24   2.06       21   2.24
7     19  2.36       25   2.00        7   3.08
8     15  2.60       12   2.78       11   2.84
9     18  2.42        6   2.90       20   2.45

Option 1: Chunk the column names

  • This option can be used to plot the data in a loop without manually creating each plot
  • The chunker function from this answer to How to iterate over a list in chunks creates a list of column-name groups:
    • [Index(['Hours', 'GPA'], dtype='object'), Index(['Hours.1', 'GPA.1'], dtype='object'), Index(['Hours.2', 'GPA.2'], dtype='object')]
# create groups of column names to be plotted together
def chunker(seq, size):
    return [seq[pos:pos + size] for pos in range(0, len(seq), size)]


# group the columns into (Hours, GPA) pairs
col_list = chunker(df.columns, 2)

# iterate through each group of column names; every call draws on the current axes
for x, y in col_list:
    sns.scatterplot(data=df, x=x, y=y, label=y)
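The loop above relies on each call drawing into matplotlib's current Axes. A self-contained sketch that passes the Axes explicitly and saves the figure might look like this; it assumes a three-row excerpt of the question's data, a headless Agg backend, and reuses the question's out.png file name:

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# first three rows of the question's grades.csv
csv_text = """Hours,GPA,Hours,GPA,Hours,GPA
11,2.84,9,2.85,25,1.85
5,3.20,5,3.35,6,3.14
22,2.18,14,2.60,9,2.96
"""
df = pd.read_csv(io.StringIO(csv_text))


def chunker(seq, size):
    """Split seq into consecutive slices of length size."""
    return [seq[pos:pos + size] for pos in range(0, len(seq), size)]


col_pairs = chunker(df.columns, 2)

fig, ax = plt.subplots()
for x, y in col_pairs:
    # passing ax explicitly keeps every pair on the same Axes
    sns.scatterplot(data=df, x=x, y=y, label=y, ax=ax)

ax.set_xlabel("Hours")
ax.set_ylabel("GPA")
fig.savefig("out.png")
```

Passing ax avoids surprises if another figure becomes current between calls.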

Option 2: Reshape the data into long form

# filter each group of columns, melt the result into a long form, and get the value
h = df.filter(like='Hours').melt().value
g = df.filter(like='GPA').melt().value

# get the gpa column names
gpa_cols = df.columns[1::2]

# use numpy to create a list of labels with the appropriate length
labels = np.repeat(gpa_cols, len(df))

# otherwise use a list comprehension to create the labels
# labels = [v for x in gpa_cols for v in [x]*len(df)]

# create a new dataframe
dfl = pd.DataFrame({'hours': h, 'gpa': g, 'label': labels})

# save dfl if desired
dfl.to_csv('gpa_long.csv', index=False)

# plot
sns.scatterplot(data=dfl, x='hours', y='gpa', hue='label')

Plot Result

[image: scatter plot of hours vs. gpa, colored by label]
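An alternative sketch of the same long-form reshape (not the method used above) pairs each Hours column with the GPA column that follows it and stacks the pairs with pd.concat, so it does not depend on filter and melt returning the two column families in matching order. It assumes a three-row excerpt of the question's data; the output file name is hypothetical:

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# first three rows of the question's grades.csv
csv_text = """Hours,GPA,Hours,GPA,Hours,GPA
11,2.84,9,2.85,25,1.85
5,3.20,5,3.35,6,3.14
22,2.18,14,2.60,9,2.96
"""
df = pd.read_csv(io.StringIO(csv_text))

# pair each Hours column with the GPA column right after it
pairs = zip(df.columns[0::2], df.columns[1::2])

# give every pair the same column names, tag it with its GPA column name, then stack
dfl = pd.concat(
    [df[[x, y]].rename(columns={x: "hours", y: "gpa"}).assign(label=y) for x, y in pairs],
    ignore_index=True,
)

fig, ax = plt.subplots()
sns.scatterplot(data=dfl, x="hours", y="gpa", hue="label", ax=ax)
fig.savefig("gpa_long.png")  # hypothetical output file name
```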

Upvotes: 1

Lucas M. Uriarte

Reputation: 3101

The problem is that the column names in your file are duplicated, so when pandas reads the file it appends a number to each repeated name in the loaded data frame:

import seaborn as sns
import pandas as pd
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

grades = pd.read_csv("grades.csv") 
print(grades.columns)
# Index(['Hours', 'GPA', 'Hours.1', 'GPA.1', 'Hours.2', 'GPA.2'], dtype='object')

Therefore, when you create the scatter plots you need to use the column names that pandas assigned:

# in case you want all scatter plots in the same figure
plot = sns.scatterplot(data=grades, x="Hours", y="GPA", label='GPA')
sns.scatterplot(data=grades, x='Hours.1', y='GPA.1', ax=plot, label="GPA.1")
sns.scatterplot(data=grades, x='Hours.2', y='GPA.2', ax=plot,  label='GPA.2')
fig = plot.get_figure()
fig.savefig("out.png")

[image: the three Hours/GPA series in one scatter plot, with legend entries GPA, GPA.1 and GPA.2]
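The code above puts every series on one Axes. If you would rather give each Hours/GPA pair its own panel, a sketch with plt.subplots might look like this; it assumes a three-row excerpt of the question's data, and the output file name is hypothetical:

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# first three rows of the question's grades.csv
csv_text = """Hours,GPA,Hours,GPA,Hours,GPA
11,2.84,9,2.85,25,1.85
5,3.20,5,3.35,6,3.14
22,2.18,14,2.60,9,2.96
"""
grades = pd.read_csv(io.StringIO(csv_text))

# the (x, y) column names that pandas assigned after deduplication
pairs = [("Hours", "GPA"), ("Hours.1", "GPA.1"), ("Hours.2", "GPA.2")]

# one panel per pair; sharey keeps the GPA axis comparable across panels
fig, axes = plt.subplots(1, len(pairs), figsize=(12, 4), sharey=True)
for ax, (x, y) in zip(axes, pairs):
    sns.scatterplot(data=grades, x=x, y=y, ax=ax)
    ax.set_title(y)

fig.tight_layout()
fig.savefig("out_panels.png")  # hypothetical output file name
```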

Upvotes: 3
