I want to visualize the process of fitting a straight line to a 2D dataset. For every epoch, I have the x and y coordinates of the line's start point and end point. Below is an example image of a dataset with a line (made with matplotlib).
I saw that Plotly provides an option to animate lines.
The plot above was created with code similar to this:
import plotly.express as px
import plotly.graph_objects as go

# this should animate my line
fig_data = px.line(df, x="xxx", y="yyy", animation_frame="epoch", animation_group="name", title="fitted line")
# responsible for the red scatter plot points
# (here xxx and yyy are the data-point coordinates, defined elsewhere):
fig_data.add_traces(
    go.Scatter(
        x=xxx,
        y=yyy, mode='markers', name='House in Dataset')
)
The DataFrame looks like this:
epoch xxx yyy name
0 0 [0.5, 4] [1.4451884285714285, 4.730202428571428] example
1 1 [0.5, 4] [1.3944818842653062, 4.4811159469795925] example
2 2 [0.5, 4] [1.3475661354539474, 4.251154573663417] example
3 3 [0.5, 4] [1.3041510122346094, 4.03885377143571] example
So the line that should be shown in epoch 0 starts from (0.5,1.44) and goes to (4,4.73). However, no line is rendered. What should I change?
I believe the problem is the format of xxx and yyy in df: each cell holds a list (the start and end value), and px.line does not expand those nested lists into separate points. px.line expects long-form data, with one row per point.
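To make the expected shape concrete, here is a minimal sketch that builds the long-form table directly, one row per point. It assumes the per-epoch endpoints are available as (x0, y0, x1, y1) tuples; the `endpoints` list is illustrative and simply copies the values from the question's table.

```python
import pandas as pd

# Illustrative per-epoch endpoints (x0, y0, x1, y1), copied from the question's table
endpoints = [
    (0.5, 1.4451884285714285, 4, 4.730202428571428),
    (0.5, 1.3944818842653062, 4, 4.4811159469795925),
    (0.5, 1.3475661354539474, 4, 4.251154573663417),
    (0.5, 1.3041510122346094, 4, 4.03885377143571),
]

rows = []
for epoch, (x0, y0, x1, y1) in enumerate(endpoints):
    # one row per point: first the start point, then the end point
    rows.append({"epoch": epoch, "xxx": x0, "yyy": y0, "name": "example"})
    rows.append({"epoch": epoch, "xxx": x1, "yyy": y1, "name": "example"})

long_df = pd.DataFrame(rows)
print(long_df)
```

Each epoch now contributes two rows, which px.line connects into one segment per animation frame.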
You can use pd.Series.explode to "flatten" this DataFrame (example here) and then use the result as the input to px.line. See here for more information on pandas' explode.
Using xdf=df.set_index('epoch').apply(pd.Series.explode).reset_index() will yield:
epoch xxx yyy name
0 0 0.5 1.445188 example
1 0 4 4.730202 example
2 1 0.5 1.394482 example
3 1 4 4.481116 example
4 2 0.5 1.347566 example
5 2 4 4.251155 example
6 3 0.5 1.304151 example
7 3 4 4.038854 example
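On pandas 1.3 or newer, a possibly simpler alternative is DataFrame.explode with a list of columns, which explodes xxx and yyy together (the lists in each row must have matching lengths). This is a sketch of that variant, not the approach used above; the astype step is an addition, because explode leaves the exploded columns with object dtype.

```python
import pandas as pd

df = pd.DataFrame({
    "epoch": [0, 1, 2, 3],
    "xxx": [[0.5, 4]] * 4,
    "yyy": [[1.4451884285714285, 4.730202428571428],
            [1.3944818842653062, 4.4811159469795925],
            [1.3475661354539474, 4.251154573663417],
            [1.3041510122346094, 4.03885377143571]],
    "name": ["example"] * 4,
})

# Explode both list columns at once (multi-column explode needs pandas >= 1.3)
xdf = df.explode(["xxx", "yyy"], ignore_index=True)

# explode leaves object dtype; convert to float so the axes are treated as numeric
xdf = xdf.astype({"xxx": float, "yyy": float})
print(xdf)
```

The result has the same eight rows as the table above, with a fresh 0..7 index because of ignore_index=True.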
Full example with comments:
import plotly.express as px
import pandas as pd
data = {'epoch': [0,1,2,3],
'xxx': [[0.5, 4], [0.5, 4], [0.5, 4], [0.5, 4]],
'yyy': [[1.4451884285714285, 4.7302024285714280],
[1.3944818842653062, 4.4811159469795925],
[1.3475661354539474, 4.2511545736634170],
[1.3041510122346094, 4.0388537714357100]],
'name':['example','example','example','example']}
df = pd.DataFrame.from_dict(data)
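# now exploding `df`: one row per point
xdf=df.set_index('epoch').apply(pd.Series.explode).reset_index()
# explode leaves object dtype, so convert the coordinates to float
xdf = xdf.astype({'xxx': float, 'yyy': float})
# now plotting using xdf as dataframe input
fig = px.line(xdf, x="xxx", y="yyy", animation_frame="epoch", color="name", title="fitted line")
fig.show()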
Note: The raw data for the scatter seems to be missing, but I don't think that's the issue here.