AG_exp

Reputation: 129

How to animate a straight line in Plotly?

I want to visualize the process of fitting a straight line to a 2D dataset. For every epoch, I have the x and y coordinates of the line's start point and end point. Below is an example image of a dataset with such a line (made with matplotlib).

[Image: example dataset with a fitted straight line, plotted with matplotlib]

I saw that Plotly provides an option to animate line plots.

[Image: animated Plotly figure of the dataset]

The above plot was created with code similar to this:

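import plotly.express as px
import plotly.graph_objects as go

# this should animate my line
fig_data = px.line(df, x="xxx", y="yyy", animation_frame="epoch", animation_group="name", title="fitted line")

# responsible for the red scatter plot points
# (xxx and yyy here are the dataset's coordinate arrays, not shown in the question):
fig_data.add_traces(
    go.Scatter(
        x=xxx,
        y=yyy, mode='markers', name='House in Dataset')
)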

The DataFrame looks like this:

   epoch       xxx                                       yyy     name
0      0  [0.5, 4]   [1.4451884285714285, 4.730202428571428]  example
1      1  [0.5, 4]  [1.3944818842653062, 4.4811159469795925]  example
2      2  [0.5, 4]   [1.3475661354539474, 4.251154573663417]  example
3      3  [0.5, 4]    [1.3041510122346094, 4.03885377143571]  example

So in epoch 0 the line should run from (0.5, 1.44) to (4, 4.73). However, no line is rendered. What should I change?

Upvotes: 3

Views: 158

Answers (1)

jayveesea

Reputation: 3199

I believe the problem is the format of xxx and yyy in df: each cell holds a nested list, and px.line does not seem to read these lists as point coordinates. px.line expects long-form data, with one row per plotted point.
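If you produce the endpoints yourself, you can skip the list columns entirely and build one row per point directly, which is the long-form layout px.line works with. A minimal sketch, assuming the per-epoch endpoints are available as (x_start, y_start, x_end, y_end) tuples; the `endpoints` dict and `long_df` names are hypothetical, and the values are the ones from the question:

```python
import pandas as pd

# Hypothetical container: per-epoch endpoints (x_start, y_start, x_end, y_end),
# values copied from the question's DataFrame.
endpoints = {
    0: (0.5, 1.4451884285714285, 4, 4.730202428571428),
    1: (0.5, 1.3944818842653062, 4, 4.4811159469795925),
    2: (0.5, 1.3475661354539474, 4, 4.251154573663417),
    3: (0.5, 1.3041510122346094, 4, 4.03885377143571),
}

# Long-form layout: each epoch contributes two rows (start and end point)
# that share the same "epoch" value, so one animation frame draws one segment.
rows = []
for epoch, (x0, y0, x1, y1) in endpoints.items():
    rows.append({"epoch": epoch, "xxx": x0, "yyy": y0, "name": "example"})
    rows.append({"epoch": epoch, "xxx": x1, "yyy": y1, "name": "example"})

long_df = pd.DataFrame(rows)
print(long_df)
```

Because the coordinates never pass through list cells, the xxx and yyy columns come out as plain float columns.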

You can use pd.Series.explode to "flatten" this DataFrame, giving each list element its own row, and then pass the result to px.line. See the pandas documentation for more on explode.
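To see what explode does to a single list column, here is a small sketch using two rows of the question's xxx column (the Series `s` is just for illustration). Note that the exploded values keep object dtype, so they may need converting with astype(float) before numeric work:

```python
import pandas as pd

# Two rows of the question's 'xxx' column: every cell holds a [start, end] list.
s = pd.Series([[0.5, 4], [0.5, 4]], index=[0, 1], name="xxx")

# explode() gives each list element its own row and repeats the original
# index label, so both points of a row keep that row's label (e.g. its epoch).
exploded = s.explode()
print(exploded)
```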

Running xdf = df.set_index('epoch').apply(pd.Series.explode).reset_index() yields:

   epoch  xxx       yyy     name
0      0  0.5  1.445188  example
1      0    4  4.730202  example
2      1  0.5  1.394482  example
3      1    4  4.481116  example
4      2  0.5  1.347566  example
5      2    4  4.251155  example
6      3  0.5  1.304151  example
7      3    4  4.038854  example
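On newer pandas (1.3 or later), DataFrame.explode also accepts a list of columns and explodes them in lockstep, which avoids the set_index/apply round trip. A sketch of that alternative, assuming pandas >= 1.3; the lists in each row must have matching lengths:

```python
import pandas as pd

df = pd.DataFrame({
    "epoch": [0, 1, 2, 3],
    "xxx": [[0.5, 4]] * 4,
    "yyy": [[1.4451884285714285, 4.730202428571428],
            [1.3944818842653062, 4.4811159469795925],
            [1.3475661354539474, 4.251154573663417],
            [1.3041510122346094, 4.03885377143571]],
    "name": ["example"] * 4,
})

# Explode both list columns together; 'epoch' and 'name' are repeated for
# each element, and ignore_index=True renumbers the rows 0..n-1.
xdf = df.explode(["xxx", "yyy"], ignore_index=True)

# explode leaves object dtype behind; convert to floats for plotting/ranges.
xdf = xdf.astype({"xxx": float, "yyy": float})
print(xdf)
```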

Full example with comments:

import plotly.express as px
import pandas as pd

data = {'epoch': [0,1,2,3], 
        'xxx': [[0.5, 4], [0.5, 4], [0.5, 4], [0.5, 4]],
        'yyy': [[1.4451884285714285, 4.7302024285714280], 
                [1.3944818842653062, 4.4811159469795925], 
                [1.3475661354539474, 4.2511545736634170], 
                [1.3041510122346094, 4.0388537714357100]],
        'name': ['example', 'example', 'example', 'example']}
df = pd.DataFrame.from_dict(data)

# now exploding `df`
xdf=df.set_index('epoch').apply(pd.Series.explode).reset_index()

# now plotting using xdf as dataframe input
fig = px.line(xdf, x="xxx", y="yyy", animation_frame="epoch", color="name", title="fitted line")
fig.show()

Note: The raw data for the scatter plot isn't included in the question, but I don't think that's the issue here.

[Image: resulting animated line plot]

Upvotes: 2
