Michael
Michael

Reputation: 2367

Animated plot with `plotly`

I'd like to plot a convergence process of the MLE algorithm with the plotly library.

Requirements:

A plot of a single iteration may be produced by Code 1, with the desired output shown in Figure 1:

Code 1

import plotly.graph_objects as go
import numpy as np

# Toy data for one iteration: 15 random 2-D points split into 5 clusters of
# 3 consecutive rows, plus one random integer-valued centroid per cluster.
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']
points_per_cluster = 3

fig = go.Figure()

# One marker trace per cluster. The slice is non-overlapping so that every
# point belongs to exactly one cluster (rows 0-2, 3-5, ..., 12-14).
for i, color in enumerate(colors):
    members = A[i * points_per_cluster:(i + 1) * points_per_cluster]
    fig.add_trace(
        go.Scatter(
            x=members[:, 0],
            y=members[:, 1],
            mode='markers',
            name=f'cluster {i+1}',
            marker_color=color,
        )
    )

# One 'x' marker per centroid, coloured like the cluster it belongs to.
for c in clusters:
    fig.add_trace(
        go.Scatter(
            x=[centroids[c-1][0]],
            y=[centroids[c-1][1]],
            name=f'centroid of cluster {c}',
            mode='markers',
            marker_color=colors[c-1],
            marker_symbol='x',
        )
    )
fig.show()

Figure 1

Figure 1

I've seen this tutorial, but it seems that only a single trace can be plotted per graph_objects.Frame(). Code 2 is a simple example that produces an animated scatter plot of all the points, where each frame plots either the points of a different cluster or a single centroid:

Code 2

import plotly.graph_objects as go
import numpy as np

# Toy data: 15 random 2-D points split into 5 clusters of 3 consecutive rows,
# plus one random integer-valued centroid per cluster.
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']
points_per_cluster = 3


def cluster_trace(c):
    """Return the marker trace for the points of cluster ``c`` (1-based)."""
    # Disjoint slices, so no two clusters share points.
    members = A[(c - 1) * points_per_cluster:c * points_per_cluster]
    return go.Scatter(x=members[:, 0], y=members[:, 1], mode='markers',
                      name=f'cluster {c}', marker_color=colors[c - 1])


def centroid_trace(c):
    """Return the 'x' marker trace for the centroid of cluster ``c`` (1-based)."""
    return go.Scatter(x=[centroids[c - 1][0]], y=[centroids[c - 1][1]],
                      mode='markers', name=f'centroid of cluster {c}',
                      marker_color=colors[c - 1], marker_symbol='x')


# Each frame carries a single trace, so every frame replaces what the previous
# one drew: clusters 2-5 first (cluster 1 is the initial data), then the five
# centroids.
frames = ([go.Frame(data=[cluster_trace(c)]) for c in clusters[1:]]
          + [go.Frame(data=[centroid_trace(c)]) for c in clusters])

fig = go.Figure(
    data=[cluster_trace(clusters[0])],
    layout=go.Layout(
        xaxis=dict(range=[-10, 10], autorange=False),
        yaxis=dict(range=[-10, 10], autorange=False),
        title="Start Title",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None])])]
    ),
    frames=frames,
)
fig.show()

Why Code 2 does not fit my needs:

What I have tried:

Code 3:

# Attempt: reuse the complete single-iteration figure (the same traces as
# Code 1) as one animation frame. The `frames=` line below raises ValueError.
import plotly.graph_objects as go
import numpy as np

A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

# Full iteration figure: one marker trace per cluster ...
fig = go.Figure()
# NOTE(review): A[i:i+3] takes overlapping windows for consecutive i
# (rows 0-2, 1-3, 2-4, ...), so neighbouring clusters share points;
# A[3*i:3*i+3] would give five disjoint groups of 3.
for i in range(5):
    fig.add_trace(
        go.Scatter(
            x=A[i:i+3][:, 0],
            y=A[i:i+3][:, 1],
            mode='markers',
            name=f'cluster {i+1}',
            marker_color=colors[i]
        )
    )
    
# ... followed by one 'x' marker per centroid, coloured like its cluster.
for c in clusters:
    fig.add_trace(
        go.Scatter(
            x=[centroids[c-1][0]],
            y=[centroids[c-1][1]],
            name=f'centroid of cluster {c}',
            mode='markers',
            marker_color=colors[c-1],
            marker_symbol='x'
        )
    )

# Animated figure that starts from a single trace (the first three points).
animated_fig = go.Figure(
    # NOTE(review): this trace plots the same rows as 'cluster 1' above but is
    # labelled 'cluster 0'; the f-string prefix has no placeholders.
    data=[go.Scatter(x=A[:3][:, 0], y=A[:3][:, 1], mode='markers', name=f'cluster 0', marker_color=colors[0])],
    layout=go.Layout(
        xaxis=dict(range=[-10, 10], autorange=False),
        yaxis=dict(range=[-10, 10], autorange=False),
        title="Start Title",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None])])]
    ),
    # go.Frame(data=...) accepts trace instances (or trace dicts), not a
    # Figure, so passing `fig` itself is what raises the ValueError. Passing
    # the figure's traces instead, e.g. go.Frame(data=fig.data), satisfies the
    # validator.
    # NOTE(review): the initial `data` above holds only one trace, so frame
    # traces beyond the first may not be drawn — confirm.
    frames=[go.Frame(data=[fig])]
)

animated_fig.show()

Error 1:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-681-11264f38e6f7> in <module>
     43                           args=[None])])]
     44     ),
---> 45     frames=[go.Frame(data=[fig])]
     46 )
     47 

~\Anaconda3\lib\site-packages\plotly\graph_objs\_frame.py in __init__(self, arg, baseframe, data, group, layout, name, traces, **kwargs)
    241         _v = data if data is not None else _v
    242         if _v is not None:
--> 243             self["data"] = _v
    244         _v = arg.pop("group", None)
    245         _v = group if group is not None else _v

~\Anaconda3\lib\site-packages\plotly\basedatatypes.py in __setitem__(self, prop, value)
   3973                 # ### Handle compound array property ###
   3974                 elif isinstance(validator, (CompoundArrayValidator, BaseDataValidator)):
-> 3975                     self._set_array_prop(prop, value)
   3976 
   3977                 # ### Handle simple property ###

~\Anaconda3\lib\site-packages\plotly\basedatatypes.py in _set_array_prop(self, prop, val)
   4428         # ------------
   4429         validator = self._get_validator(prop)
-> 4430         val = validator.validate_coerce(val, skip_invalid=self._skip_invalid)
   4431 
   4432         # Save deep copies of current and new states

~\Anaconda3\lib\site-packages\_plotly_utils\basevalidators.py in validate_coerce(self, v, skip_invalid, _validate)
   2671 
   2672             if invalid_els:
-> 2673                 self.raise_invalid_elements(invalid_els)
   2674 
   2675             v = to_scalar_or_list(res)

~\Anaconda3\lib\site-packages\_plotly_utils\basevalidators.py in raise_invalid_elements(self, invalid_els)
    298                     pname=self.parent_name,
    299                     invalid=invalid_els[:10],
--> 300                     valid_clr_desc=self.description(),
    301                 )
    302             )

ValueError: 
    Invalid element(s) received for the 'data' property of frame
        Invalid elements include: [Figure({
    'data': [{'marker': {'color': 'red'},
              'mode': 'markers',
              'name': 'cluster 1',
              'type': 'scatter',
              'x': array([-1.30634452, -1.73005459,  0.58746435]),
              'y': array([ 0.15388112,  0.47452796, -1.86354483])},
             {'marker': {'color': 'green'},
              'mode': 'markers',
              'name': 'cluster 2',
              'type': 'scatter',
              'x': array([-1.73005459,  0.58746435, -0.27492892]),
              'y': array([ 0.47452796, -1.86354483, -0.20329897])},
             {'marker': {'color': 'blue'},
              'mode': 'markers',
              'name': 'cluster 3',
              'type': 'scatter',
              'x': array([ 0.58746435, -0.27492892,  0.21002816]),
              'y': array([-1.86354483, -0.20329897,  1.99487636])},
             {'marker': {'color': 'yellow'},
              'mode': 'markers',
              'name': 'cluster 4',
              'type': 'scatter',
              'x': array([-0.27492892,  0.21002816, -0.0148647 ]),
              'y': array([-0.20329897,  1.99487636,  0.73484184])},
             {'marker': {'color': 'magenta'},
              'mode': 'markers',
              'name': 'cluster 5',
              'type': 'scatter',
              'x': array([ 0.21002816, -0.0148647 ,  1.13589386]),
              'y': array([1.99487636, 0.73484184, 2.08810809])},
             {'marker': {'color': 'red', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 1',
              'type': 'scatter',
              'x': [9],
              'y': [6]},
             {'marker': {'color': 'green', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 2',
              'type': 'scatter',
              'x': [0],
              'y': [5]},
             {'marker': {'color': 'blue', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 3',
              'type': 'scatter',
              'x': [8],
              'y': [6]},
             {'marker': {'color': 'yellow', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 4',
              'type': 'scatter',
              'x': [7],
              'y': [1]},
             {'marker': {'color': 'magenta', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 5',
              'type': 'scatter',
              'x': [6],
              'y': [2]}],
    'layout': {'template': '...'}
})]

    The 'data' property is a tuple of trace instances
    that may be specified as:
      - A list or tuple of trace instances
        (e.g. [Scatter(...), Bar(...)])
      - A single trace instance
        (e.g. Scatter(...), Bar(...), etc.)
      - A list or tuple of dicts of string/value properties where:
        - The 'type' property specifies the trace type
            One of: ['area', 'bar', 'barpolar', 'box',
                     'candlestick', 'carpet', 'choropleth',
                     'choroplethmapbox', 'cone', 'contour',
                     'contourcarpet', 'densitymapbox', 'funnel',
                     'funnelarea', 'heatmap', 'heatmapgl',
                     'histogram', 'histogram2d',
                     'histogram2dcontour', 'image', 'indicator',
                     'isosurface', 'mesh3d', 'ohlc', 'parcats',
                     'parcoords', 'pie', 'pointcloud', 'sankey',
                     'scatter', 'scatter3d', 'scattercarpet',
                     'scattergeo', 'scattergl', 'scattermapbox',
                     'scatterpolar', 'scatterpolargl',
                     'scatterternary', 'splom', 'streamtube',
                     'sunburst', 'surface', 'table', 'treemap',
                     'violin', 'volume', 'waterfall']

        - All remaining properties are passed to the constructor of
          the specified trace type

        (e.g. [{'type': 'scatter', ...}, {'type': 'bar, ...}])

Code 4:

import plotly.express as px
import numpy as np
import pandas as pd

# Toy data: 100 points over 5 iterations (20 rows per iteration) with random
# cluster labels 1-5, plus one integer-valued centroid row per iteration.
A = np.random.randn(200).reshape((100, 2))
iteration = np.array([1, 2, 3, 4, 5]).repeat(20)
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = np.random.randint(1, 6, size=100)
colors = ['red', 'green', 'blue', 'yellow', 'magenta']
cluster_labels = ['1', '2', '3', '4', '5']

df = pd.DataFrame(dict(x1=A[:, 0], x2=A[:, 1], type='point', cluster=pd.Series(clusters, dtype='str'), iteration=iteration))
# Cluster labels must be strings here too. Mixing str and int labels makes
# plotly treat '1' and 1 as different colour categories.
centroid_df = pd.DataFrame(dict(x1=centroids[:, 0], x2=centroids[:, 1], type='centroid', cluster=cluster_labels, iteration=[1, 2, 3, 4, 5]))
# DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
df = pd.concat([df, centroid_df], ignore_index=True)

fig = px.scatter(
    df, x="x1", y="x2",
    animation_frame="iteration",
    color="cluster",
    # Fix the category order so that cluster k is always drawn in colors[k-1].
    category_orders={"cluster": cluster_labels},
    color_discrete_sequence=colors,
    hover_name="cluster",
    range_x=[-10, 10], range_y=[-10, 10],
)
fig.show()

I'd appreciate any help for achieving the desired result. Thanks.

Upvotes: 4

Views: 6679

Answers (1)

rpanai
rpanai

Reputation: 13437

You can add two traces per frame, but apparently you need to define those two traces in the initial `data` as well. I also added the first two traces again as a frame so that they remain visible on subsequent plays. Here is the full code:

import plotly.graph_objects as go
import numpy as np

# Toy data: 15 random 2-D points split into 5 clusters of 3 consecutive rows,
# plus one random integer-valued centroid per cluster.
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']
points_per_cluster = 3


def cluster_traces(c):
    """Return [points trace, centroid trace] for cluster ``c`` (1-based)."""
    # Disjoint slices, so no two clusters share points.
    members = A[(c - 1) * points_per_cluster:c * points_per_cluster]
    color = colors[c - 1]
    return [
        go.Scatter(x=members[:, 0],
                   y=members[:, 1],
                   mode='markers',
                   name=f'cluster {c}',
                   marker_color=color),
        go.Scatter(x=[centroids[c - 1][0]],
                   y=[centroids[c - 1][1]],
                   mode='markers',
                   name=f'centroid of cluster {c}',
                   marker_color=color,
                   marker_symbol='x'),
    ]


fig = go.Figure(
    # The initial data must already contain as many traces as each frame
    # (two here); otherwise the extra frame traces are not drawn.
    data=cluster_traces(clusters[0]),
    layout=go.Layout(
        xaxis=dict(range=[-10, 10], autorange=False),
        yaxis=dict(range=[-10, 10], autorange=False),
        title="Start Title",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None]),
                     # Pausing needs a list containing None ([None]) as the
                     # frame argument. A bare None, as used by Play, animates
                     # all frames instead of stopping.
                     dict(label="Pause",
                          method="animate",
                          args=[[None],
                                {"frame": {"duration": 0, "redraw": False},
                                 "mode": "immediate",
                                 "transition": {"duration": 0}}],
                          )])]
    ),
    # The first frame repeats the initial traces so that cluster 1 is shown
    # again when the animation is replayed.
    frames=[go.Frame(data=cluster_traces(c)) for c in clusters],
)

fig.show()

enter image description here

Upvotes: 5

Related Questions