Reputation: 11
I have created a treemap with squarify, and now I'm trying to create a legend that displays the treemap's data next to the graph.
The built-in legend function does not generate the legend I want (it currently displays the first column of my data frame and the index of each row), so I've been experimenting with it, without success. I would like the legend to be:
SKU Volume in Units
a 1
b 2
c 3
d 4
e 5
f 6
g 7
h 20
import matplotlib as mpl
import squarify
import matplotlib.cm
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
#create figure
fig = plt.gcf()
fig.set_size_inches(16, 5)
#set data values
data = [['a',1],['b',2],['c',3],['d',4],['e',5],['f',6],['g',7],['h',20]]
data_slice = pd.DataFrame(data, columns=['Atom SKU Code','Total Volume'])
print(data_slice)
#create color set
norm = mpl.colors.Normalize(vmin=min(data_slice['Total Volume']), vmax=max(data_slice['Total Volume']))
colors = [mpl.cm.BuGn(norm(value)) for value in data_slice['Total Volume']]
#plot figure
ax1 = squarify.plot(label=data_slice['Atom SKU Code'], sizes=data_slice['Total Volume'], color=colors, alpha=.6)
plt.title("Volume by SKU (Units Sold)", fontsize=23, fontweight="bold")
plt.axis('off')
plt.legend(title='SKU Volume in Units', loc='center left',bbox_to_anchor=(1, 0.5),frameon=False)
plt.tight_layout()
plt.show()
Upvotes: 0
Views: 422
Reputation: 11
Figured out a way to do this. I removed the label and used plt.table to create my own version of a legend. I also reused the colors I had already created to make the legend look nicer.
picture: output with legend table
import matplotlib as mpl
import matplotlib.colors

def genereate_legend_table(ax, colors, no_of_skus_to_graph, data_slice):
    # ___________________________________________________________________________________
    # DESCRIPTION
    # This function generates a legend table to be plotted with the treemap plots
    #------------------------------------------------------------------------------------
    # ARGUMENTS
    # ax: subplot axis where the table is to be plotted
    # colors: color list object generated by matplotlib's color library
    # no_of_skus_to_graph: how many SKUs are being represented in the treemap
    # data_slice: the data to be written in the table
    # ___________________________________________________________________________________
    # Create hex color list from normalized color object
    hex_list = [mpl.colors.to_hex(color) for color in colors]
    # Create table object on desired axis (edges='' draws no cell borders)
    legend_table = ax.table(cellText=data_slice.values,
                            colLabels=data_slice.columns,
                            loc='right',
                            colLoc='right',
                            colWidths=[0.2, 0.2],
                            edges='')
    # Change table text color; row 0 holds the column labels, so data rows start at 1
    for i in range(no_of_skus_to_graph):
        legend_table[(i + 1, 0)].get_text().set_color(hex_list[i])
        legend_table[(i + 1, 1)].get_text().set_color(hex_list[i])
        # Prefix the value with ¥ and format it with thousands separators and two decimals
        s = '¥{:,.2f}'.format(float(legend_table[(i + 1, 1)].get_text().get_text()))
        legend_table[(i + 1, 1)].get_text().set_text(s)
    # Return table
    return legend_table
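For comparison, here is a rough sketch of a different route to the legend layout asked for in the question: pass proxy patches to the regular legend instead of drawing a table. This is not what the function above does. The helper name add_sku_legend and its parameters are made up for illustration, and the treemap itself is left out so the snippet runs with matplotlib alone.

```python
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, to_hex
from matplotlib.patches import Patch


def add_sku_legend(ax, skus, volumes, cmap_name="BuGn",
                   title="SKU Volume in Units"):
    """Hypothetical helper: add one colored legend entry per SKU."""
    # Same color scheme as the question: scale volumes into the colormap range
    norm = Normalize(vmin=min(volumes), vmax=max(volumes))
    cmap = plt.get_cmap(cmap_name)
    colors = [to_hex(cmap(norm(v))) for v in volumes]
    # Proxy patches are never drawn on the axes; they only feed the legend
    handles = [Patch(facecolor=c, label=f"{sku}  {vol}")
               for c, sku, vol in zip(colors, skus, volumes)]
    legend = ax.legend(handles=handles, title=title, loc="center left",
                       bbox_to_anchor=(1, 0.5), frameon=False)
    return legend, colors


skus = list("abcdefgh")
volumes = [1, 2, 3, 4, 5, 6, 7, 20]

fig, ax = plt.subplots(figsize=(16, 5))
# The treemap would be drawn on the same axes before adding the legend, e.g.
# squarify.plot(sizes=volumes, color=colors, alpha=.6, ax=ax), without label=...
legend, colors = add_sku_legend(ax, skus, volumes)
ax.axis("off")

# Collect the legend entries as plain strings for inspection
legend_labels = [text.get_text() for text in legend.get_texts()]
```

Each entry reads as the SKU followed by its volume, in the same color the treemap cell would get. If the treemap is drawn with alpha=.6, passing the same alpha to Patch should keep the swatches visually consistent. Call plt.tight_layout() and plt.show() afterwards as in the question.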
Upvotes: 1