Reputation: 123
I'm plotting scatter3d projections of the 4d iris data set using plotly. To display all 4 possible projections in the same figure I am using sliders. However when "sliding" from one projection to the next the axis titles do not change. Normally I would use fig.update_layout()
but that isn't working. How can I get these to change with the slider?
Here's the code for reference:
import numpy as np
import plotly.graph_objects as go
from matplotlib import cm
from itertools import combinations
def nd2scatter3d(X, labels = None, features = None, plot_axes = None, hovertext = None):
"""
Parameters
----------
X : array-like, shape = (n_samples, n_features).
labels : 1d int array, shape = (n_samples), optional, default None.
Target or clustering labels for each sample.
Defaults to np.ones(n_samples).
features : list, len = n_features, optional, default None.
List of feature names.
Defaults to numeric labeling.
plot_axes : list of 3-tuples, optional, default None.
List of axes to include in 3d projections. i.e. [(0,1,2), (0,1,3)] displays
projections along the 4th axis and 3rd axis in that order.
Defaults to all possible axes combinations.
hovertext : list, len = n_samples, optional, default None.
List of text to display on mouse hover.
Defaults to no text on hover.
"""
if labels is None:
labels = np.ones(X.shape[0]).astype(int)
if features is None:
features = np.arange(X.shape[1]).astype(str)
if plot_axes is None:
plot_axes = list(combinations(np.arange(X.shape[1]), 3))
if hovertext is None:
hoverinfo = 'none'
else:
hoverinfo = 'text'
fig = go.Figure()
for i in range(len(plot_axes)):
fig.add_trace(
go.Scatter3d(
visible=False,
x=X[:, plot_axes[i][0]],
y=X[:, plot_axes[i][1]],
z=X[:, plot_axes[i][2]],
mode='markers',
marker=dict(
size=3,
color = [list(cm.tab10.colors[c]) for c in labels],
opacity=1
),
hovertemplate=None,
hoverinfo= hoverinfo,
hovertext = hovertext,
),)
fig.data[0].visible = True
steps = []
for i in range(len(fig.data)):
step = dict(
method="update",
args=[{"visible": [False] * len(fig.data)},
{"title": features[plot_axes[i][0]] + ' vs. ' + features[plot_axes[i][1]] + ' vs. ' + features[plot_axes[i][2]]}, # layout attribute
],
label = str(plot_axes[i]),
)
step["args"][0]["visible"][i] = True # Toggle i'th trace to "visible"
steps.append(step)
sliders = [dict(
active=10,
currentvalue={"prefix": "Projection: "},
pad={"t": 10},
steps=steps,
)]
fig.update_layout(
sliders=sliders
)
fig.update_layout(width=900, height = 500, margin=dict(r=45, l=45, b=10, t=50),
showlegend=False)
fig.update_layout(scene_aspectmode='cube',
scene2_aspectmode='cube',
scene3_aspectmode='cube',
scene4_aspectmode='cube',
scene = dict(
xaxis_title = features[plot_axes[0][0]],
yaxis_title = features[plot_axes[0][1]],
zaxis_title = features[plot_axes[0][2]],),
scene2 = dict(
xaxis_title = features[plot_axes[1][0]],
yaxis_title = features[plot_axes[1][1]],
zaxis_title = features[plot_axes[1][2]],),
scene3 = dict(
xaxis_title = features[plot_axes[2][0]],
yaxis_title = features[plot_axes[2][1]],
zaxis_title = features[plot_axes[2][2]],),
scene4 = dict(
xaxis_title = features[plot_axes[3][0]],
yaxis_title = features[plot_axes[3][1]],
zaxis_title = features[plot_axes[3][2]],)
)
fig.show()
Solution thanks to jayveesea, as well as some minor changes:
def nd2scatter3d(X, labels = None, features = None, plot_axes = None, hovertext = None, size = 3):
"""
Parameters
----------
X : array-like, shape = (n_samples, n_features).
labels : 1d int array, shape = (n_samples), optional, default None.
Target or clustering labels for each sample.
Defaults to np.ones(n_samples).
features : list, len = n_features, optional, default None.
List of feature names.
Defaults to numeric labeling.
plot_axes : list of 3-tuples, optional, default None.
List of axes to include in 3d projections. i.e. [(0,1,2), (0,1,3)] displays
projections along the 4th axis and 3rd axis in that order.
Defaults to all possible axes combinations.
hovertext : list, len = n_samples, optional, default None.
List of text to display on mouse hover.
Defaults to no text on hover.
size : int, default 3.
Sets marker size.
"""
if labels is None:
# Label all datapoints zero.
labels = np.zeros(X.shape[0]).astype(int)
if features is None:
# numerical features if no names are passed.
features = np.arange(X.shape[1]).astype(str)
if plot_axes is None:
# plot all possible axes if none are passed.
plot_axes = list(combinations(np.arange(X.shape[1]), 3))
if hovertext is None:
hoverinfo = 'none'
else:
hoverinfo = 'text'
# Determine colormap from number of labels.
if len(np.unique(labels)) <= 10:
color = [list(cm.tab10.colors[c]) if c >= 0 else [0,0,0,1] for c in labels]
elif len(np.unique(labels)) <= 20:
color = [list(cm.tab20.colors[c]) if c >= 0 else [0,0,0,1] for c in labels]
else:
norm_labels = labels/max(labels)
color = [cm.viridis(c) if c >= 0 else [0,0,0,1] for c in norm_labels]
# Genterate 3d scatter plot slider.
fig = go.Figure()
for i in range(len(plot_axes)):
fig.add_trace(
# Scatter plot params.
go.Scatter3d(
visible=False,
x=X[:, plot_axes[i][0]],
y=X[:, plot_axes[i][1]],
z=X[:, plot_axes[i][2]],
mode='markers',
marker=dict(
size=size,
color = color,
opacity=1
),
hovertemplate=None,
hoverinfo= hoverinfo,
hovertext = hovertext,
),)
fig.data[0].visible = True
steps = []
# Slider update params.
for i in range(len(fig.data)):
step = dict(
method="update",
args=[{"visible": [False] * len(fig.data)},
{"title": features[plot_axes[i][0]] + ' vs. '
+ features[plot_axes[i][1]] + ' vs. ' + features[plot_axes[i][2]],
"scene.xaxis.title": features[plot_axes[i][0]],
"scene.yaxis.title": features[plot_axes[i][1]],
"scene.zaxis.title": features[plot_axes[i][2]],
},
],
label = str(plot_axes[i]),
)
step["args"][0]["visible"][i] = True # Toggle i'th trace to "visible".
steps.append(step)
sliders = [dict(
active=10,
currentvalue={"prefix": "Projection: (x, y, z) = "},
pad={"t": 10},
steps=steps,
)]
fig.update_layout(sliders=sliders)
fig.update_layout(width=900, height = 500, margin=dict(r=45, l=45, b=10, t=50))
fig.update_layout(scene_aspectmode='cube')
fig.show()
Upvotes: 2
Views: 1332
Reputation: 3199
To update the axis titles you need to include the axis names with your slider entry. It may help to reference plotly's js document on update
.
So instead of this chunk:
for i in range(len(fig.data)):
step = dict(
method="update",
args=[{"visible": [False] * len(fig.data)},
{"title": features[plot_axes[i][0]] + ' vs. '
+ features[plot_axes[i][1]] + ' vs. ' + features[plot_axes[i][2]]},
],
label = str(plot_axes[i]),
)
Use something like:
for i in range(len(fig.data)):
step = dict(
method="update",
args=[{"visible": [False] * len(fig.data)},
{"title": features[plot_axes[i][0]] + ' vs. '
+ features[plot_axes[i][1]] + ' vs. ' + features[plot_axes[i][2]],
"scene.xaxis.title": features[plot_axes[i][0]],
"scene.yaxis.title": features[plot_axes[i][1]],
"scene.zaxis.title": features[plot_axes[i][2]],
},
],
label = str(plot_axes[i]),
)
This creates an entry that will update the data and title and the axes titles when the slider changes.
Upvotes: 3