Reputation: 3085
I would like to add to the Plotly "standard legend" produced by plotly.express one extra legend for each of the "graph dimensions", grouping together all traces that share that characteristic. As this is probably hard to understand from the description alone, let me give an example. I have code that produces a plot with the following line:
```python
import plotly.express as px

px.line(
    df,
    x='x values',
    y='y values',
    color='Device specs',  # This is what I call "color dimension".
    symbol='Device',  # This is what I call "symbol dimension".
    line_dash='Contact type',  # This is what I call "line_dash dimension".
)
```
and the plot looks (for some specific data) like this:
Below this legend I want to add one legend for each "dimension": one for color, grouping all the traces of each color, one for symbol, and one for line_dash, something like this:

Ideally, clicking e.g. contact=dot would also toggle the visibility of all the dashed traces together.
Is this possible with Plotly Express?
Upvotes: 4
Views: 1064
Reputation: 31166
```python
import numpy as np
import pandas as pd
import plotly.express as px

SIZE = 10

# generate a dataset with all required attributes:
# SIZE series of SIZE points; each series gets one random value per attribute
df = pd.DataFrame(
    {
        "x values": np.tile(np.linspace(0, SIZE - 1, SIZE), SIZE),
        "y values": np.sort(np.random.uniform(1, 1000, SIZE**2)),
        "Device": np.concatenate(
            [np.full(SIZE, np.random.choice([52, 36, 34], 1)) for _ in range(SIZE)]
        ),
        "Contact type": np.concatenate(
            [np.full(SIZE, np.random.choice(["dot", "ring"], 1)) for _ in range(SIZE)]
        ),
        "Device specs": np.concatenate(
            [
                np.full(SIZE, np.random.choice(["laptop", "tablet", "console"], 1))
                for _ in range(SIZE)
            ]
        ),
    }
)
# NaN at the last x of each series, so series that end up in the same trace
# (same attribute combination) are not joined by a line
df.loc[df["x values"].eq(SIZE - 1), "y values"] = np.nan

# build the standard figure
fig = px.line(
    df,
    x="x values",
    y="y values",
    color="Device specs",  # This is what I call "color dimension".
    symbol="Device",  # This is what I call "symbol dimension".
    line_dash="Contact type",  # This is what I call "line_dash dimension".
)

# build additional traces for the items wanted in the legend: one dummy figure
# per dimension, all at y = -1000. px assigns styles in order of first
# appearance in df, so the dummies get the same colours/symbols/dashes.
legend_traces = [
    px.line(
        df,
        x="x values",
        y=np.full(len(df), -1000),
        **param["px"],
    ).update_traces(**param["lg"], legendgroup=str(param["px"]))
    for param in [
        {"px": {"color": "Device specs"}, "lg": {"legendgrouptitle_text": "Spec"}},
        {"px": {"symbol": "Device"}, "lg": {"legendgrouptitle_text": "Device"}},
        {
            "px": {"line_dash": "Contact type"},
            "lg": {"legendgrouptitle_text": "Contact type"},
        },
    ]
]
for t in legend_traces:
    fig.add_traces(t.data)

# hide the dummy traces for extra legend entries (given y-value of -1000)
fig.update_yaxes(range=[0, df["y values"].max()])
fig.update_layout(height=500)
fig.show()
```
Upvotes: 2