Brian Droncheff

Seaborn plot list of values vs. their indexes

I'm trying to tap into seaborn's potential for drawing a line plot, but I keep getting the error "TypeError: unhashable type: 'list'". I understand that the problem is that the values in the dataframe's Accuracies column are lists, but I don't know how to correctly plot those lists against their indexes (which I also made into a list called 'Accuracy_Indexes'). Basically, I'm just trying to plot each of the 'Accuracies' on the X-axis and their indices on the Y-axis, with the hue being the 'Bin'. So it is two lists of the same size for the X and Y axes. I'm not sure whether catplot is the right plot for this, as opposed to something like a lineplot, but I get warnings that it is old and going to be removed in the future.

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

N=7
Accuracy_Indexes = list(range(1, 5))
df = pd.DataFrame({'Bin': np.random.randint(1,10,N)})
TheList = []
for item in range(N):
    TheList.append([item+1, np.random.random_sample(size=len(Accuracy_Indexes)).tolist()])
print(TheList)

df = pd.DataFrame.from_records(TheList,columns=['Bin','Accuracies'])
print(df)

g = sns.catplot(x=Accuracy_Indexes, y='Accuracies', hue="Bin",
                data=df, height=5, aspect=.8)
plt.show()

Thanks for any advice.

Answers (1)

JohanC

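Pandas' explode() can be used to convert the lists into separate rows: each list element gets its own row, with the Bin value repeated. Tiling the Accuracy_Indexes (repeating the whole list once per original row, in the same order) then creates a new column holding the matching index for each value.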
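To see what these two steps produce, here is a minimal sketch on a tiny hand-made frame. The values are made up for illustration; only the column names follow the question.

```python
import numpy as np
import pandas as pd

# Made-up frame in the same shape as the question's df:
# one row per Bin, each holding a list of accuracies.
Accuracy_Indexes = [1, 2, 3]
df = pd.DataFrame({'Bin': [1, 2],
                   'Accuracies': [[0.5, 0.6, 0.7], [0.8, 0.9, 1.0]]})

# explode() gives every list element its own row; Bin is repeated for each one.
df_long = df.explode('Accuracies', ignore_index=True)

# np.tile repeats the whole index list once per original row: [1, 2, 3, 1, 2, 3].
# This only lines up because every list has len(Accuracy_Indexes) elements.
df_long['Accuracy_Indexes'] = np.tile(Accuracy_Indexes, len(df))

# The exploded column still has dtype object, so convert it for plotting.
df_long['Accuracies'] = df_long['Accuracies'].astype(float)

print(df_long)
```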

Here is some example code to illustrate the idea:

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

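# Recreate the question's data: one row per Bin, each holding a list of accuracies
N = 7
Accuracy_Indexes = list(range(1, 5))
TheList = []
for item in range(N):
    TheList.append([item + 1, np.random.random_sample(size=len(Accuracy_Indexes)).tolist()])

df = pd.DataFrame.from_records(TheList, columns=['Bin', 'Accuracies'])

# One row per list element; the Bin value is repeated for each element
df_long = df.explode('Accuracies', ignore_index=True)
# Repeat [1, 2, 3, 4] once per original row so each value gets its index
df_long['Accuracy_Indexes'] = np.tile(Accuracy_Indexes, len(df))
# explode() leaves the values with dtype object; convert them to floats
df_long['Accuracies'] = df_long['Accuracies'].astype(float)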

g = sns.catplot(x='Accuracy_Indexes', y='Accuracies', hue="Bin",
                data=df_long, height=5, aspect=.8)
plt.tight_layout()
plt.show()

[Image: resulting catplot]
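np.tile only lines up when every list has exactly len(Accuracy_Indexes) elements. If that might not hold, one alternative (a different technique from the tiling above) is to number the elements within each original row with groupby().cumcount(). A minimal sketch, with made-up values and deliberately unequal list lengths:

```python
import pandas as pd

# Made-up frame; the second list is shorter on purpose.
df = pd.DataFrame({'Bin': [1, 2],
                   'Accuracies': [[0.5, 0.6, 0.7], [0.8, 0.9]]})

# Without ignore_index, explode() keeps the original row label for each element;
# reset_index() moves that label into a column named 'index'.
df_long = df.explode('Accuracies').reset_index()

# cumcount() numbers the rows within each original row: 0, 1, 2, ...
# Adding 1 gives 1-based positions like the question's Accuracy_Indexes.
df_long['Accuracy_Indexes'] = df_long.groupby('index').cumcount() + 1

df_long = df_long.drop(columns='index')
df_long['Accuracies'] = df_long['Accuracies'].astype(float)
print(df_long)
```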

From the same long-format dataframe (df_long), you can also create a line plot via relplot with kind='line':

g = sns.relplot(x='Accuracy_Indexes', y='Accuracies', hue="Bin",
                data=df_long, height=5, aspect=.8, kind='line')

Note that relplot and catplot are "figure-level functions", which create a complete figure at once, with one or more subplots; its size is set per subplot through height and aspect. If you only want a single subplot, you can use the axes-level function lineplot instead. Figure size is then controlled differently: create the figure first, e.g. with plt.figure(figsize=...), and let lineplot draw into it.

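from matplotlib.ticker import MaxNLocator

plt.figure(figsize=(4, 5))  # width, height in inches
ax = sns.lineplot(x='Accuracy_Indexes', y='Accuracies', hue="Bin", data=df_long)
# Only show whole-number ticks on the x-axis, since the indexes are integers
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
plt.show()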

[Image: resulting lineplot]
