Reputation: 2367
I'd like to plot a convergence process of the MLE
algorithm with the plotly
library.
Requirements:
A plot of a single iteration may be produced by Code 1
, with the desired output shown in Figure 1
:
Code 1
import plotly.graph_objects as go
import numpy as np
# Demo data: 15 random 2-D points and 5 integer-valued centroids.
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']
points_per_cluster = len(A) // len(clusters)

fig = go.Figure()

# One marker trace per cluster. Consecutive, non-overlapping row blocks are
# used so that every point belongs to exactly one cluster; slicing with
# A[i:i+3] would overlap neighbouring clusters and leave rows 7-14 unplotted.
for i, color in enumerate(colors):
    members = A[i * points_per_cluster:(i + 1) * points_per_cluster]
    fig.add_trace(
        go.Scatter(
            x=members[:, 0],
            y=members[:, 1],
            mode='markers',
            name=f'cluster {i+1}',
            marker_color=color,
        )
    )

# One 'x' marker per centroid, drawn in the colour of its cluster.
for c, (cx, cy), color in zip(clusters, centroids, colors):
    fig.add_trace(
        go.Scatter(
            x=[cx],
            y=[cy],
            name=f'centroid of cluster {c}',
            mode='markers',
            marker_color=color,
            marker_symbol='x',
        )
    )

fig.show()
Figure 1
I've seen this tutorial, but it seems that you can plot only a single trace in a graph_objects.Frame()
, and Code 2
represents a simple example for producing an animated scatter plot of all the points, where each frame plots either the points of a different cluster or a single centroid:
Code 2
import plotly.graph_objects as go
import numpy as np
# Demo data: 15 random 2-D points and 5 integer-valued centroids.
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

# Fixed axis range shared by both axes, and the single "Play" button.
axis_settings = dict(range=[-10, 10], autorange=False)
play_button = dict(label="Play", method="animate", args=[None])

# Trace shown before the animation starts.
first_trace = go.Scatter(
    x=A[:3][:, 0],
    y=A[:3][:, 1],
    mode='markers',
    name='cluster 1',
    marker_color=colors[0],
)

# (rows of A, cluster number) for the frames that show points.
point_frame_specs = [
    (slice(None, 3), 2),
    (slice(3, 5), 3),
    (slice(5, 8), 4),
    (slice(8, None), 5),
]
point_frames = [
    go.Frame(data=[go.Scatter(
        x=A[rows][:, 0],
        y=A[rows][:, 1],
        mode='markers',
        name=f'cluster {number}',
        marker_color=colors[number - 1],
    )])
    for rows, number in point_frame_specs
]

# One frame per centroid, each holding a single 'x' marker.
centroid_frames = [
    go.Frame(data=[go.Scatter(
        x=[centroid[0]],
        y=[centroid[1]],
        mode='markers',
        name=f'centroid of cluster {number}',
        marker_color=color,
        marker_symbol='x',
    )])
    for number, (centroid, color) in enumerate(zip(centroids, colors), start=1)
]

fig = go.Figure(
    data=[first_trace],
    layout=go.Layout(
        xaxis=dict(axis_settings),
        yaxis=dict(axis_settings),
        title="Start Title",
        updatemenus=[dict(type="buttons", buttons=[play_button])],
    ),
    frames=point_frames + centroid_frames,
)
fig.show()
Why Code 2 does not fit my needs:
Code 2
shows only a single trace per frame, whereas I need to plot in a single frame each iteration of the algorithm (i.e. each frame of the desired solution will look like Figure 1
).
What I have tried:
I tried building a graph_objects.Figure()
, and adding it to a graph_objects.Frame()
as shown in Code 3
, but got Error 1
.
Code 3:
import plotly.graph_objects as go
import numpy as np
# Demo data: 15 random 2-D points and 5 integer-valued centroids.
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

# Build the complete single-iteration figure: five point traces followed by
# five centroid traces.
fig = go.Figure()
for idx, color in enumerate(colors):
    window = A[idx:idx + 3]
    point_trace = go.Scatter(
        x=window[:, 0],
        y=window[:, 1],
        mode='markers',
        name=f'cluster {idx+1}',
        marker_color=color,
    )
    fig.add_trace(point_trace)

for number in clusters:
    centroid = centroids[number - 1]
    centroid_trace = go.Scatter(
        x=[centroid[0]],
        y=[centroid[1]],
        name=f'centroid of cluster {number}',
        mode='markers',
        marker_color=colors[number - 1],
        marker_symbol='x',
    )
    fig.add_trace(centroid_trace)

axis_settings = dict(range=[-10, 10], autorange=False)
play_button = dict(label="Play", method="animate", args=[None])
initial_trace = go.Scatter(
    x=A[:3][:, 0],
    y=A[:3][:, 1],
    mode='markers',
    name='cluster 0',
    marker_color=colors[0],
)

# NOTE: the frame below is handed a whole Figure instead of trace objects.
# This is the attempt under discussion; it raises the ValueError reported as
# Error 1 ("Invalid element(s) received for the 'data' property of frame").
animated_fig = go.Figure(
    data=[initial_trace],
    layout=go.Layout(
        xaxis=dict(axis_settings),
        yaxis=dict(axis_settings),
        title="Start Title",
        updatemenus=[dict(type="buttons", buttons=[play_button])],
    ),
    frames=[go.Frame(data=[fig])],
)
animated_fig.show()
Error 1:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-681-11264f38e6f7> in <module>
43 args=[None])])]
44 ),
---> 45 frames=[go.Frame(data=[fig])]
46 )
47
~\Anaconda3\lib\site-packages\plotly\graph_objs\_frame.py in __init__(self, arg, baseframe, data, group, layout, name, traces, **kwargs)
241 _v = data if data is not None else _v
242 if _v is not None:
--> 243 self["data"] = _v
244 _v = arg.pop("group", None)
245 _v = group if group is not None else _v
~\Anaconda3\lib\site-packages\plotly\basedatatypes.py in __setitem__(self, prop, value)
3973 # ### Handle compound array property ###
3974 elif isinstance(validator, (CompoundArrayValidator, BaseDataValidator)):
-> 3975 self._set_array_prop(prop, value)
3976
3977 # ### Handle simple property ###
~\Anaconda3\lib\site-packages\plotly\basedatatypes.py in _set_array_prop(self, prop, val)
4428 # ------------
4429 validator = self._get_validator(prop)
-> 4430 val = validator.validate_coerce(val, skip_invalid=self._skip_invalid)
4431
4432 # Save deep copies of current and new states
~\Anaconda3\lib\site-packages\_plotly_utils\basevalidators.py in validate_coerce(self, v, skip_invalid, _validate)
2671
2672 if invalid_els:
-> 2673 self.raise_invalid_elements(invalid_els)
2674
2675 v = to_scalar_or_list(res)
~\Anaconda3\lib\site-packages\_plotly_utils\basevalidators.py in raise_invalid_elements(self, invalid_els)
298 pname=self.parent_name,
299 invalid=invalid_els[:10],
--> 300 valid_clr_desc=self.description(),
301 )
302 )
ValueError:
Invalid element(s) received for the 'data' property of frame
Invalid elements include: [Figure({
'data': [{'marker': {'color': 'red'},
'mode': 'markers',
'name': 'cluster 1',
'type': 'scatter',
'x': array([-1.30634452, -1.73005459, 0.58746435]),
'y': array([ 0.15388112, 0.47452796, -1.86354483])},
{'marker': {'color': 'green'},
'mode': 'markers',
'name': 'cluster 2',
'type': 'scatter',
'x': array([-1.73005459, 0.58746435, -0.27492892]),
'y': array([ 0.47452796, -1.86354483, -0.20329897])},
{'marker': {'color': 'blue'},
'mode': 'markers',
'name': 'cluster 3',
'type': 'scatter',
'x': array([ 0.58746435, -0.27492892, 0.21002816]),
'y': array([-1.86354483, -0.20329897, 1.99487636])},
{'marker': {'color': 'yellow'},
'mode': 'markers',
'name': 'cluster 4',
'type': 'scatter',
'x': array([-0.27492892, 0.21002816, -0.0148647 ]),
'y': array([-0.20329897, 1.99487636, 0.73484184])},
{'marker': {'color': 'magenta'},
'mode': 'markers',
'name': 'cluster 5',
'type': 'scatter',
'x': array([ 0.21002816, -0.0148647 , 1.13589386]),
'y': array([1.99487636, 0.73484184, 2.08810809])},
{'marker': {'color': 'red', 'symbol': 'x'},
'mode': 'markers',
'name': 'centroid of cluster 1',
'type': 'scatter',
'x': [9],
'y': [6]},
{'marker': {'color': 'green', 'symbol': 'x'},
'mode': 'markers',
'name': 'centroid of cluster 2',
'type': 'scatter',
'x': [0],
'y': [5]},
{'marker': {'color': 'blue', 'symbol': 'x'},
'mode': 'markers',
'name': 'centroid of cluster 3',
'type': 'scatter',
'x': [8],
'y': [6]},
{'marker': {'color': 'yellow', 'symbol': 'x'},
'mode': 'markers',
'name': 'centroid of cluster 4',
'type': 'scatter',
'x': [7],
'y': [1]},
{'marker': {'color': 'magenta', 'symbol': 'x'},
'mode': 'markers',
'name': 'centroid of cluster 5',
'type': 'scatter',
'x': [6],
'y': [2]}],
'layout': {'template': '...'}
})]
The 'data' property is a tuple of trace instances
that may be specified as:
- A list or tuple of trace instances
(e.g. [Scatter(...), Bar(...)])
- A single trace instance
(e.g. Scatter(...), Bar(...), etc.)
- A list or tuple of dicts of string/value properties where:
- The 'type' property specifies the trace type
One of: ['area', 'bar', 'barpolar', 'box',
'candlestick', 'carpet', 'choropleth',
'choroplethmapbox', 'cone', 'contour',
'contourcarpet', 'densitymapbox', 'funnel',
'funnelarea', 'heatmap', 'heatmapgl',
'histogram', 'histogram2d',
'histogram2dcontour', 'image', 'indicator',
'isosurface', 'mesh3d', 'ohlc', 'parcats',
'parcoords', 'pie', 'pointcloud', 'sankey',
'scatter', 'scatter3d', 'scattercarpet',
'scattergeo', 'scattergl', 'scattermapbox',
'scatterpolar', 'scatterpolargl',
'scatterternary', 'splom', 'streamtube',
'sunburst', 'surface', 'table', 'treemap',
'violin', 'volume', 'waterfall']
- All remaining properties are passed to the constructor of
the specified trace type
(e.g. [{'type': 'scatter', ...}, {'type': 'bar, ...}])
I also tried the plotly.express
module, as shown in Code 4
, but the only thing that is missing there is for the centroids to be marked as x
s.
Code 4:
import plotly.express as px
import numpy as np
import pandas as pd
# Demo data: 100 random 2-D points, 20 per iteration over 5 iterations, each
# assigned a random cluster label in 1..5, plus 5 integer-valued centroids.
A = np.random.randn(200).reshape((100, 2))
iteration = np.array([1, 2, 3, 4, 5]).repeat(20)
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = np.random.randint(1, 6, size=100)
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

df = pd.DataFrame(dict(
    x1=A[:, 0],
    x2=A[:, 1],
    type='point',
    cluster=pd.Series(clusters, dtype='str'),
    iteration=iteration,
))

# Cluster labels are strings here as well, so that a centroid falls into the
# same colour group as the points of its cluster (1 and '1' would otherwise
# be treated as two different categories).
centroid_df = pd.DataFrame(dict(
    x1=centroids[:, 0],
    x2=centroids[:, 1],
    type='centroid',
    cluster=[str(label) for label in range(1, 6)],
    iteration=[1, 2, 3, 4, 5],
))

# DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
df = pd.concat([df, centroid_df], ignore_index=True)

fig = px.scatter(
    df,
    x="x1",
    y="x2",
    animation_frame="iteration",
    color="cluster",
    hover_name="cluster",
    range_x=[-10, 10],
    range_y=[-10, 10],
)
# Show explicitly so the figure also renders when run outside a notebook.
fig.show()
I'd appreciate any help for achieving the desired result. Thanks.
Upvotes: 4
Views: 6679
Reputation: 13437
You can add two traces per frame, but apparently you need to define these two traces in the initial data
too. I added the first two traces again as a frame so that they remain visible on subsequent plays. Here is the full code:
import plotly.graph_objects as go
import numpy as np
# Demo data: 15 random 2-D points and 5 integer-valued centroids.
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

# Rows of A displayed for each cluster, in cluster order. The second entry
# repeats the first one, so clusters 1 and 2 show the same points.
point_rows = [
    slice(None, 3),
    slice(None, 3),
    slice(3, 5),
    slice(5, 8),
    slice(8, None),
]


def cluster_traces(index):
    """Return the [points, centroid] pair of traces for one cluster.

    ``index`` is the 0-based position of the cluster in ``clusters``.
    """
    points = A[point_rows[index]]
    label = clusters[index]
    color = colors[index]
    return [
        go.Scatter(x=points[:, 0],
                   y=points[:, 1],
                   mode='markers',
                   name=f'cluster {label}',
                   marker_color=color),
        go.Scatter(x=[centroids[index][0]],
                   y=[centroids[index][1]],
                   mode='markers',
                   name=f'centroid of cluster {label}',
                   marker_color=color,
                   marker_symbol='x'),
    ]


fig = go.Figure(
    # The initial data holds the same two traces (points + centroid) that
    # every frame then updates.
    data=cluster_traces(0),
    layout=go.Layout(
        xaxis=dict(range=[-10, 10], autorange=False),
        yaxis=dict(range=[-10, 10], autorange=False),
        title="Start Title",
        updatemenus=[dict(
            type="buttons",
            buttons=[
                dict(label="Play",
                     method="animate",
                     args=[None]),
                # Pausing requires [None] (a list) as the first argument;
                # a bare None means "animate all frames", i.e. play.
                dict(label="Pause",
                     method="animate",
                     args=[[None],
                           {"frame": {"duration": 0, "redraw": False},
                            "mode": "immediate",
                            "transition": {"duration": 0}}]),
            ])]
    ),
    # The first frame repeats the initial data so that cluster 1 is shown
    # again when the animation is replayed.
    frames=[go.Frame(data=cluster_traces(i)) for i in range(len(clusters))],
)
fig.show()
Upvotes: 5