onb
onb

Reputation: 123

How to change axis titles when using sliders in plotly

I'm plotting scatter3d projections of the 4d iris data set using plotly. To display all 4 possible projections in the same figure I am using sliders. However when "sliding" from one projection to the next the axis titles do not change. Normally I would use fig.update_layout() but that isn't working. How can I get these to change with the slider?

Projection 1

Projection 2

Here's the code for reference:

import numpy as np
import plotly.graph_objects as go
from matplotlib import cm
from itertools import combinations

def nd2scatter3d(X, labels = None, features = None, plot_axes = None, hovertext = None):
    """
    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features).

    labels : 1d int array, shape = (n_samples), optional, default None.
        Target or clustering labels for each sample.
        Defaults to np.ones(n_samples).

    features : list, len = n_features, optional, default None.
        List of feature names.
        Defaults to numeric labeling.

    plot_axes : list of 3-tuples, optional, default None.
        List of axes to include in 3d projections. i.e. [(0,1,2), (0,1,3)] displays
        projections along the 4th axis and 3rd axis in that order.
        Defaults to all possible axes combinations.

    hovertext : list, len = n_samples, optional, default None.
        List of text to display on mouse hover.
        Defaults to no text on hover.
    """
    if labels is None:
        labels = np.ones(X.shape[0]).astype(int)
    if features is None:
        features = np.arange(X.shape[1]).astype(str)
    if plot_axes is None:
        plot_axes = list(combinations(np.arange(X.shape[1]), 3))
    if hovertext is None:
        hoverinfo = 'none'
    else:
        hoverinfo = 'text'

    fig = go.Figure()
    for i in range(len(plot_axes)):
        fig.add_trace(
            go.Scatter3d(
                visible=False,
                x=X[:, plot_axes[i][0]],
                y=X[:, plot_axes[i][1]],
                z=X[:, plot_axes[i][2]],
                mode='markers',
                marker=dict(
                    size=3,
                    color = [list(cm.tab10.colors[c]) for c in labels],
                    opacity=1
                ),
                hovertemplate=None,
                hoverinfo= hoverinfo,
                hovertext = hovertext,
              ),)

    fig.data[0].visible = True
    steps = []
    for i in range(len(fig.data)):
        step = dict(
            method="update",
            args=[{"visible": [False] * len(fig.data)},
                  {"title": features[plot_axes[i][0]] + ' vs. ' + features[plot_axes[i][1]] + ' vs. ' + features[plot_axes[i][2]]},  # layout attribute
                 ],
            label = str(plot_axes[i]),
                    )

        step["args"][0]["visible"][i] = True  # Toggle i'th trace to "visible"
        steps.append(step)

    sliders = [dict(
        active=10,
        currentvalue={"prefix": "Projection: "},
        pad={"t": 10},
        steps=steps,
                )]


    fig.update_layout(
        sliders=sliders
    )
    fig.update_layout(width=900, height = 500, margin=dict(r=45, l=45, b=10, t=50),
                     showlegend=False)

    fig.update_layout(scene_aspectmode='cube',
                      scene2_aspectmode='cube',
                      scene3_aspectmode='cube',
                      scene4_aspectmode='cube',
                      scene = dict(
                        xaxis_title = features[plot_axes[0][0]],
                        yaxis_title = features[plot_axes[0][1]],
                        zaxis_title = features[plot_axes[0][2]],),
                      scene2 = dict(
                        xaxis_title = features[plot_axes[1][0]],
                        yaxis_title = features[plot_axes[1][1]],
                        zaxis_title = features[plot_axes[1][2]],),
                      scene3 = dict(
                        xaxis_title = features[plot_axes[2][0]],
                        yaxis_title = features[plot_axes[2][1]],
                        zaxis_title = features[plot_axes[2][2]],),
                      scene4 = dict(
                        xaxis_title = features[plot_axes[3][0]],
                        yaxis_title = features[plot_axes[3][1]],
                        zaxis_title = features[plot_axes[3][2]],)
                     )
    fig.show()

Solution thanks to jayveesea, as well as some minor changes:

def nd2scatter3d(X, labels = None, features = None, plot_axes = None, hovertext = None, size = 3):
    """
    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features).

    labels : 1d int array, shape = (n_samples), optional, default None.
        Target or clustering labels for each sample.
        Defaults to np.ones(n_samples).

    features : list, len = n_features, optional, default None.
        List of feature names.
        Defaults to numeric labeling.

    plot_axes : list of 3-tuples, optional, default None.
        List of axes to include in 3d projections. i.e. [(0,1,2), (0,1,3)] displays
        projections along the 4th axis and 3rd axis in that order.
        Defaults to all possible axes combinations.

    hovertext : list, len = n_samples, optional, default None.
        List of text to display on mouse hover.
        Defaults to no text on hover.
    size : int, default 3.
        Sets marker size.
    """

    if labels is None:
        # Label all datapoints zero.
        labels = np.zeros(X.shape[0]).astype(int)
    if features is None:
        # numerical features if no names are passed.
        features = np.arange(X.shape[1]).astype(str)
    if plot_axes is None:
        # plot all possible axes if none are passed.
        plot_axes = list(combinations(np.arange(X.shape[1]), 3))
    if hovertext is None:
        hoverinfo = 'none'
    else:
        hoverinfo = 'text'

    # Determine colormap from number of labels.
    if len(np.unique(labels)) <= 10:
        color = [list(cm.tab10.colors[c]) if c >= 0 else [0,0,0,1] for c in labels]
    elif len(np.unique(labels)) <= 20:
        color = [list(cm.tab20.colors[c]) if c >= 0 else [0,0,0,1] for c in labels]
    else:
        norm_labels = labels/max(labels)
        color = [cm.viridis(c) if c >= 0 else [0,0,0,1] for c in norm_labels]

    # Genterate 3d scatter plot slider.
    fig = go.Figure()
    for i in range(len(plot_axes)):
        fig.add_trace(
            # Scatter plot params.
            go.Scatter3d(
                visible=False,
                x=X[:, plot_axes[i][0]],
                y=X[:, plot_axes[i][1]],
                z=X[:, plot_axes[i][2]],
                mode='markers',
                marker=dict(
                    size=size,
                    color = color,
                    opacity=1
                ),
                hovertemplate=None,
                hoverinfo= hoverinfo,
                hovertext = hovertext,
              ),)

    fig.data[0].visible = True
    steps = []

    # Slider update params.
    for i in range(len(fig.data)):
        step = dict(
            method="update",
            args=[{"visible": [False] * len(fig.data)},
                  {"title": features[plot_axes[i][0]] + ' vs. ' 
                       + features[plot_axes[i][1]] + ' vs. ' + features[plot_axes[i][2]],
                   "scene.xaxis.title": features[plot_axes[i][0]],
                   "scene.yaxis.title": features[plot_axes[i][1]],
                   "scene.zaxis.title": features[plot_axes[i][2]],
                  },
                 ],
            label = str(plot_axes[i]),
            )

        step["args"][0]["visible"][i] = True  # Toggle i'th trace to "visible".
        steps.append(step)

    sliders = [dict(
        active=10,
        currentvalue={"prefix": "Projection: (x, y, z) = "},
        pad={"t": 10},
        steps=steps,
                )]
    fig.update_layout(sliders=sliders)
    fig.update_layout(width=900, height = 500, margin=dict(r=45, l=45, b=10, t=50))
    fig.update_layout(scene_aspectmode='cube')
    fig.show()

Upvotes: 2

Views: 1332

Answers (1)

jayveesea
jayveesea

Reputation: 3199

To update the axis titles you need to include the axis names with your slider entry. It may help to reference plotly's js document on update.

So instead of this chunk:

for i in range(len(fig.data)):
        step = dict(
            method="update",
            args=[{"visible": [False] * len(fig.data)},
                  {"title": features[plot_axes[i][0]] + ' vs. ' 
                       + features[plot_axes[i][1]] + ' vs. ' + features[plot_axes[i][2]]},
                 ],
            label = str(plot_axes[i]),
                    )

Use something like:

for i in range(len(fig.data)):
        step = dict(
            method="update",
            args=[{"visible": [False] * len(fig.data)},
                  {"title": features[plot_axes[i][0]] + ' vs. ' 
                       + features[plot_axes[i][1]] + ' vs. ' + features[plot_axes[i][2]],
                   "scene.xaxis.title": features[plot_axes[i][0]],
                   "scene.yaxis.title": features[plot_axes[i][1]],
                   "scene.zaxis.title": features[plot_axes[i][2]],
                  },
                 ],
            label = str(plot_axes[i]),
            )

This creates an entry that will update the data and title and the axes titles when the slider changes.

Upvotes: 3

Related Questions