Piotr Rarus

Reputation: 952

Animated scatter plot over surface (Plotly)

I'm trying to make an animated scatter plot over a fixed surface using Plotly.

This is the code I use to draw the surface:

import plotly.graph_objects as go


def surface(x, y, z, opacity: float = 1.0) -> go.Figure:

    fig = go.Figure()

    fig.add_trace(
        go.Surface(
            x=x,
            y=y,
            z=z,
            contours_z=dict(
                show=True,
                usecolormap=True,
                project_z=True,
            ),
            opacity=opacity
        )
    )

    return fig

[Image: surface plot]

Then I try to overlay scatter plots on it:


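import plotly.graph_objects as go
import plotly.offline as pyo

# CEC2013, BatchStats and the `figure` module come from the project's own code


def population(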
    self,
    benchmark: CEC2013,
    batch_stats: BatchStats,
    filename: str = 'population'
):
    # Here I'm creating the surface figure using previous method
    surface = figure.surface.surface(*benchmark.surface, opacity=0.8)

    frames = []

    # Time to add some frames using Scatter 3D
    for population in batch_stats.population_history:
        x = [solution.genome[0] for solution in population.solutions]
        y = [solution.genome[1] for solution in population.solutions]
        fitness = population.fitness

        frame = go.Frame(
            data=[
                go.Scatter3d(
                    x=x,
                    y=y,
                    z=fitness,
                    mode='markers',
                    marker=dict(
                        size=6,
                        color='#52CA34',
                    )
                )
            ]
        )

        frames.append(frame)

    # Update frames to root figure
    surface.frames = frames

    # Display inline in offline mode, so it also works in an HTML render of the notebook
    pyo.iplot(
        surface,
        filename=filename,
        image_width=self.IMAGE_WIDTH,
        image_height=self.IMAGE_HEIGHT
    )

The surface is displayed only in the first frame. The scatter plots are displayed in the subsequent frames, but without the underlying surface.

[Image: scatter plot]

The code is located here, on a dev branch. There's a debug notebook named test_vis.ipynb at the root. Thanks for the help <3

Upvotes: 1

Views: 3280

Answers (1)

Piotr Rarus

I've posted this question as an issue in Plotly's repository.

Here's the answer I received.

@empet - thank you <3

When your fig.data contains only one trace, each frame is assumed to update that trace, so the original trace is no longer displayed during the animation. That's why you must define:

fig = go.Figure(
    data=[
        go.Scatter(
            y=y,
            x=x,
            mode="lines",
            line_shape='spline'
        )
    ]*2
)

i.e., include the same trace twice in fig.data. At the same time, modify the frame definition as follows:

frame = go.Frame(data=[scatter], traces=[1])

traces = [1] informs plotly.js that each frame updates the trace fig.data[1], while fig.data[0] is unchanged during the animation.

Here's a complete example of how to make an animation over a base plot.

import math
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import plotly.offline as pyo


class Plot:

    def __init__(
        self,
        image_width: int = 1200,
        image_height: int = 900
    ) -> None:

        self.IMAGE_WIDTH = image_width
        self.IMAGE_HEIGHT = image_height

        pyo.init_notebook_mode(connected=False)
        pio.renderers.default = 'notebook'

    def population(self, filename: str = 'population'):

        x_spline = np.linspace(
            start=0,
            stop=20,
            num=100,
            endpoint=True
        )

        y_spline = np.array([math.sin(x_i) for x_i in x_spline])

        x_min = np.min(x_spline)
        x_max = np.max(x_spline)

        y_min = np.min(y_spline)
        y_max = np.max(y_spline)

        spline = go.Scatter(
            y=y_spline,
            x=x_spline,
            mode="lines",
            line_shape='spline'
        )

        # data[0] is the static spline; data[1] is the copy that the frames update
        fig = go.Figure(
            data=[spline] * 2
        )

        frames = []

        for i in range(50):
            x = np.random.random_sample(size=5)
            x *= x_max

            y = np.array([math.sin(x_i) for x_i in x])

            scatter = go.Scatter(
                x=x,
                y=y,
                mode='markers',
                marker=dict(
                    color='Green',
                    size=12,
                    line=dict(
                        color='Red',
                        width=2
                    )
                ),
            )
            # traces=[1]: this frame updates fig.data[1] only
            frame = go.Frame(data=[scatter], traces=[1])
            frames.append(frame)

        fig.frames = frames

        fig.layout = go.Layout(
            xaxis=dict(
                range=[x_min, x_max],
                autorange=False
            ),
            yaxis=dict(
                range=[y_min, y_max],
                autorange=False
            ),
            title="Start Title",
            updatemenus=[
                dict(
                    type="buttons",
                    buttons=[
                        dict(
                            label="Play",
                            method="animate",
                            args=[None]
                        )
                    ]
                )
            ]
        )

        fig.update_layout(
            xaxis_title='x',
            yaxis_title='y',
            title='Fitness landscape',
            # autosize=True
        )

        pyo.iplot(
            fig,
            filename=filename,
            image_width=self.IMAGE_WIDTH,
            image_height=self.IMAGE_HEIGHT
        )

plot = Plot()
plot.population()

Upvotes: 3
