Reputation: 57033
I would like to produce a scatter plot of a pandas DataFrame with categorical row and column labels using matplotlib. A sample DataFrame looks like this:
import pandas as pd
df = pd.DataFrame({"a": [1,2], "b": [3,4]}, index=["c","d"])
#    a  b
# c  1  3
# d  2  4
The marker size is a function of the respective DataFrame values. So far, I have come up with an awkward solution that enumerates the rows and columns, plots the data, and then reconstructs the labels:
import matplotlib.pyplot as plt
flat = df.reset_index(drop=True).T.reset_index(drop=True).T.stack().reset_index()
#    level_0  level_1  0
# 0        0        0  1
# 1        0        1  3
# 2        1        0  2
# 3        1        1  4
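# level_1 holds the column positions (x) and level_0 the row positions (y),
# which matches the tick labels set below
flat.plot(kind='scatter', x='level_1', y='level_0', s=100*flat[0])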
plt.xticks(range(df.shape[1]), df.columns)
plt.yticks(range(df.shape[0]), df.index)
plt.show()
My question: is there a more intuitive, more integrated way to produce this scatter plot, ideally without splitting the data from the metadata?
Upvotes: 8
Views: 2106
Reputation: 471
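You could use a NumPy array and pd.melt to create the scatter plot, as shown below:
import numpy as np
import matplotlib.pyplot as plt
# one (x, y) pair per cell, column by column, matching the order of pd.melt(df)
arr = np.array([[i,j] for i in range(df.shape[1]) for j in range(df.shape[0])])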
plt.scatter(arr[:,0],arr[:,1],s=100*pd.melt(df)['value'],marker='o')
plt.xlabel('level_0')
plt.ylabel('level_1')
plt.xticks(range(df.shape[1]), df.columns)
plt.yticks(range(df.shape[0]), df.index)
plt.show()
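A variation on the pd.melt idea keeps the labels next to the values, so the ticks do not have to be rebuilt by hand. This is only a minimal sketch: it assumes matplotlib 2.1 or later (where string coordinates are plotted as categories), it assumes the row and column labels are strings, and the names `long`, `plot_long`, `row`, `column` and `value` are made up for this illustration.

```python
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}, index=["c", "d"])

# Long format: one record per cell, keeping the labels next to the value.
# "row", "column" and "value" are names chosen for this sketch.
long = (
    df.rename_axis("row")
      .reset_index()
      .melt(id_vars="row", var_name="column", value_name="value")
)


def plot_long(long, scale=100, ax=None):
    """Scatter the cells of a long-format frame on categorical axes."""
    if ax is None:
        _, ax = plt.subplots()
    # String coordinates are treated as categories, so no manual ticks are set.
    ax.scatter(long["column"].tolist(), long["row"].tolist(),
               s=scale * long["value"])
    return ax
```

Calling `plot_long(long)` followed by `plt.show()` should draw the same four markers. Categories appear on each axis in order of first appearance in the data.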
Upvotes: 3
Reputation: 8207
Maybe not the entire answer you're looking for, but here is an idea that saves time and improves the readability of the flat = ... line of code.
The pandas unstack method produces a Series with a MultiIndex.
dfu = df.unstack()
print(dfu.index)
MultiIndex(levels=[[u'a', u'b'], [u'c', u'd']],
labels=[[0, 0, 1, 1], [0, 1, 0, 1]])
The MultiIndex contains the x and y positions needed to construct the plot (in labels, which pandas 0.24 and later call codes). Here, I assign levels and codes to more informative variable names better suited for plotting.
xlabels, ylabels = dfu.index.levels
xs, ys = dfu.index.codes  # dfu.index.labels in pandas older than 0.24
Plotting is pretty straightforward from here.
plt.scatter(xs, ys, s=dfu*100)
plt.xticks(range(len(xlabels)), xlabels)
plt.yticks(range(len(ylabels)), ylabels)
plt.show()
I tried this on a few different DataFrame shapes and it seemed to hold up.
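For reference, the steps above can be collected into one self-contained sketch. It assumes pandas 0.24 or later, where the integer positions of a MultiIndex are exposed as `codes`; the helper names `scatter_coords` and `labelled_scatter` and the `scale` parameter are invented for this example and are not part of pandas or matplotlib.

```python
import matplotlib.pyplot as plt
import pandas as pd


def scatter_coords(df):
    """Return integer positions, values and tick labels for each cell of df."""
    s = df.unstack()                    # Series indexed by (column label, row label)
    xs, ys = s.index.codes              # integer positions into the two levels
    xlabels, ylabels = s.index.levels   # the labels those positions refer to
    return xs, ys, s.to_numpy(), list(xlabels), list(ylabels)


def labelled_scatter(df, scale=100, ax=None):
    """Scatter plot with marker area proportional to the cell values."""
    if ax is None:
        _, ax = plt.subplots()
    xs, ys, values, xlabels, ylabels = scatter_coords(df)
    ax.scatter(xs, ys, s=scale * values)
    # Restore the categorical labels on both axes.
    ax.set_xticks(range(len(xlabels)))
    ax.set_xticklabels(xlabels)
    ax.set_yticks(range(len(ylabels)))
    ax.set_yticklabels(ylabels)
    return ax
```

With the sample frame from the question, `labelled_scatter(df)` followed by `plt.show()` should give the same picture as the snippet above.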
Upvotes: 7
Reputation: 210842
It's not exactly what you were asking for, but it helps to visualize values in a similar way:
import seaborn as sns
sns.heatmap(df[::-1], annot=True)
Upvotes: 4