Reputation: 305
I like to use Plotly to visualize everything. I'm trying to visualize a confusion matrix with Plotly; this is my code:
def plot_confusion_matrix(y_true, y_pred, class_names):
    """Display a Plotly heatmap of the confusion matrix of two label sequences.

    The matrix comes from ``metrics.confusion_matrix(y_true, y_pred)``
    (``metrics`` is presumably ``sklearn.metrics`` -- confirm against the
    file's imports), is cast to ``int`` and drawn with ``class_names`` on
    both axes. The figure is shown with ``fig.show()``; nothing is returned.
    """
    # Compute and cast in one step; the heatmap only needs the final array.
    counts = metrics.confusion_matrix(y_true, y_pred).astype(int)

    heatmap = go.Heatmap(
        z=counts,
        x=class_names,
        y=class_names,
        hoverongaps=False,
    )
    axis_layout = {
        "title": "Confusion Matrix",
        "xaxis": {"title": "Predicted value"},
        "yaxis": {"title": "Real value"},
    }

    go.Figure(data=heatmap, layout=axis_layout).show()
and the result is
How can I show the number inside the corresponding cell instead of only on hover, like this?
Upvotes: 10
Views: 26336
Reputation: 1
My code:
import plotly.express as px
from sklearn.metrics import confusion_matrix
def display_confusion_matrix(y_true, y_pred, labels=(0, 1)):
    """Show an annotated confusion-matrix heatmap built with Plotly Express.

    Args:
        y_true: ground-truth labels, passed straight to ``confusion_matrix``.
        y_pred: predicted labels, passed straight to ``confusion_matrix``.
        labels: class names used for the axis ticks (converted with ``str``).
            They are NOT forwarded to ``confusion_matrix``, so their number
            and order must match the matrix that call produces.
            NOTE(review): with scikit-learn's defaults that is the sorted set
            of classes present in the data -- confirm for your inputs.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    # The default is an immutable tuple so no list object is shared
    # between calls (the original used a mutable ``[0, 1]`` default).
    cm = confusion_matrix(y_true, y_pred)
    dims = [str(label) for label in labels]

    fig = px.imshow(
        cm,
        x=dims,
        y=dims,
        color_continuous_scale='Reds',
        aspect="auto",
    )
    # Print each cell's count inside the cell instead of only on hover.
    fig.update_traces(text=cm, texttemplate="%{text}")
    fig.update_layout(
        title="Confusion matrix",
        xaxis_title='Predicted',
        yaxis_title='True',
        dragmode='select',
        width=500,
        height=500,
        hovermode='closest',
        template='seaborn',
    )
    fig.show()
Upvotes: 0
Reputation: 81
I found @vestland's strategy to be the most useful.
However, unlike a traditional confusion matrix, the correct model predictions are along the upper-right diagonal, not the upper-left.
This can easily be fixed by reversing the row order of the confusion matrix (and of the y-axis labels), as shown below:
import plotly.figure_factory as ff
# Example matrix; each inner list is one heatmap row.
z = [
    [0.1, 0.3, 0.5, 0.2],
    [1.0, 0.8, 0.6, 0.1],
    [0.1, 0.3, 0.6, 0.9],
    [0.6, 0.4, 0.2, 0.2],
]

# Reverse the row order so the first class ends up in the top row.
z = list(reversed(z))

x = ['healthy', 'multiple diseases', 'rust', 'scab']
# The y labels must be reversed in step with the rows of z.
y = list(reversed(x))

# The annotated heatmap wants its cell text as strings.
z_text = [[str(value) for value in row] for row in z]

# Build the figure.
fig = ff.create_annotated_heatmap(
    z,
    x=x,
    y=y,
    annotation_text=z_text,
    colorscale='Viridis',
)

# Title, plus extra top/left margin to make room for the y-axis title.
fig.update_layout(
    title_text='<i><b>Confusion matrix</b></i>',
    margin=dict(t=50, l=200),
)

# Custom x-axis title, positioned in paper coordinates.
fig.add_annotation(
    text="Predicted value",
    x=0.5,
    y=-0.15,
    xref="paper",
    yref="paper",
    showarrow=False,
    font=dict(color="black", size=14),
)

# Custom y-axis title, rotated to read bottom-to-top.
fig.add_annotation(
    text="Real value",
    x=-0.35,
    y=0.5,
    xref="paper",
    yref="paper",
    textangle=-90,
    showarrow=False,
    font=dict(color="black", size=14),
)

# Turn the colorbar on for the heatmap trace.
fig.data[0].showscale = True
fig.show()
Upvotes: 8
Reputation: 61094
You can use annotated heatmaps with ff.create_annotated_heatmap()
to get this:
Complete code:
import plotly.figure_factory as ff
# Example matrix; each inner list is one heatmap row.
z = [
    [0.1, 0.3, 0.5, 0.2],
    [1.0, 0.8, 0.6, 0.1],
    [0.1, 0.3, 0.6, 0.9],
    [0.6, 0.4, 0.2, 0.2],
]

x = ['healthy', 'multiple diseases', 'rust', 'scab']
y = ['healthy', 'multiple diseases', 'rust', 'scab']

# Annotation text has to be strings, one per cell.
z_text = [list(map(str, row)) for row in z]

# Set up the annotated heatmap.
fig = ff.create_annotated_heatmap(
    z, x=x, y=y, annotation_text=z_text, colorscale='Viridis'
)

# Figure title.
fig.update_layout(title_text='<i><b>Confusion matrix</b></i>')

# Custom axis titles, placed in paper coordinates: x-axis first, then y-axis.
fig.add_annotation(
    dict(
        font=dict(color="black", size=14),
        x=0.5,
        y=-0.15,
        showarrow=False,
        text="Predicted value",
        xref="paper",
        yref="paper",
    )
)
fig.add_annotation(
    dict(
        font=dict(color="black", size=14),
        x=-0.35,
        y=0.5,
        showarrow=False,
        text="Real value",
        textangle=-90,
        xref="paper",
        yref="paper",
    )
)

# Widen the left margin so the rotated y-axis title fits.
fig.update_layout(margin=dict(t=50, l=200))

# Show the colorbar on the heatmap trace.
fig.update_traces(showscale=True, selector=0)
fig.show()
Upvotes: 18
Reputation: 256
As @vestland says, you can annotate a figure with Plotly. The heatmap works like any other kind of Plotly figure. Here's some code for plotting a heatmap from a confusion matrix (basically just a 2-D list of numbers).
def plot_confusion_matrix(cm, labels, title):
    """Build an annotated Plotly heatmap for a confusion matrix.

    Args:
        cm: confusion matrix as a list of rows (list of lists).
        labels: class names (list of str), used for both axes.
        title: title for the heatmap.

    Returns:
        The ``go.Figure``; the caller decides when to show it.
    """
    data = go.Heatmap(z=cm, y=labels, x=labels)

    # go.Heatmap draws cm[i][j] at y=labels[i], x=labels[j] (rows run along
    # y, columns along x). The annotations must use the same mapping,
    # otherwise each number lands on the transposed cell.
    annotations = [
        {
            "x": labels[j],
            "y": labels[i],
            "font": {"color": "white"},
            "text": str(value),
            "xref": "x1",
            "yref": "y1",
            "showarrow": False,
        }
        for i, row in enumerate(cm)
        for j, value in enumerate(row)
    ]

    layout = {
        "title": title,
        "xaxis": {"title": "Predicted value"},
        "yaxis": {"title": "Real value"},
        "annotations": annotations,
    }
    fig = go.Figure(data=data, layout=layout)
    return fig
Upvotes: 6