I have a 2D list and I want to create a Plotly multi-line chart to represent this data.
So far I've got this code:
print(len(conv_loss),len(conv_loss[0]))
print(np.array(conv_loss).shape)
conv_loss_df = pd.DataFrame(data=conv_loss, index=namelist, columns=rs)
print(conv_loss_df)
which outputs:
4 10
(4, 10)
0.0 0.1 0.2 ... 0.7 0.8 0.9
mnist 0.020498 0.123125 0.222588 ... 1.122625 1.387823 1.701249
fashion_mnist 0.232772 0.316569 0.433325 ... 1.281240 1.545556 1.830893
cifar10 0.957889 0.851946 0.921106 ... 1.645510 1.815194 2.104631
cifar100 3.516734 3.052485 3.021778 ... 3.688753 3.937770 4.599526
I want the x axis to be rs, the line colors to be namelist, and the y axis to be the corresponding values themselves.
I have played around with the examples in the Plotly line chart docs, but I just can't get it to work.
The only way I've got it to run without an error is:
fig = px.line(conv_loss_df,x=rs)
fig.show()
which produces an entirely wrong graph.
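The desired graph can be drawn by passing the transposed DataFrame to px.line. After transposing, the index holds the rs values (x axis) and each column is one dataset, which px.line draws as its own colored line. This example was created with sample data.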
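import numpy as np
import pandas as pd
import plotly.express as px # plotly ver 4.14.1
rs = np.arange(0.1, 1.0, 0.1)
namelist = ['mnist', 'fashion_mnist', 'cifar10', 'cifar100']
# random sample data: one array of len(rs) losses per dataset
mnist = np.random.rand(len(rs))
fashion_mnist = np.random.rand(len(rs))
cifar10 = np.random.rand(len(rs))
cifar100 = np.random.rand(len(rs))
# rows = datasets, columns = rs values (same layout as in the question)
conv_loss_df = pd.DataFrame([mnist, fashion_mnist, cifar10, cifar100], index=namelist, columns=rs)
# px.line draws each column of a wide DataFrame as one line against the index,
# so transpose: rs becomes the x axis and each dataset becomes one colored line
fig = px.line(conv_loss_df.T)
fig.show()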