Gilad Green

Reputation: 37299

Plotly animated subplots with px.imshow and go.Scatter

I am trying to create a figure showing image "reconstruction" as a function of the number of PCs. I want to animate it to show the original image, the cumulative reconstruction (over PCs 1,...,i) and the part that still remains to be "reconstructed". Alongside these, I want to show the distance between the original and the reconstructed image as a function of the number of PCs.

I managed to create the figure below, which animates the scatter plot at the bottom and also the images at the top.

enter image description here

The problem is that once the animation begins, the two images on the right "disappear". I think they end up underneath the "Original Image".

enter image description here

This is the code I have (creation of animation frames with all 3 images and scatters, and then formation of figure):

import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
from sklearn.decomposition import PCA

pio.templates["custom"] = go.layout.Template(
    layout=go.Layout(
        margin=dict(l=20, r=20, t=40, b=0)
    )
)
pio.templates.default = "simple_white+custom"


class AnimationButtons:
    @staticmethod
    def play_scatter(frame_duration = 500, transition_duration = 300):
        return dict(label="Play", method="animate", args=
                    [None, {"frame": {"duration": frame_duration, "redraw": False},
                            "fromcurrent": True, "transition": {"duration": transition_duration, "easing": "quadratic-in-out"}}])

    @staticmethod
    def play(frame_duration = 1000, transition_duration = 0):
        return dict(label="Play", method="animate", args=
                    [None, {"frame": {"duration": frame_duration, "redraw": True},
                            "mode":"immediate",
                            "fromcurrent": True, "transition": {"duration": transition_duration, "easing": "linear"}}])

    @staticmethod
    def pause():
        return dict(label="Pause", method="animate", args=
                    [[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate", "transition": {"duration": 0}}])

pca = PCA(n_components=15).fit(X.reshape((X.shape[0], -1)))
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))

img, loadings = X[1], pca.transform(X[1].reshape(-1, 1)).T


reconstructed, distortion, frames = np.zeros_like(X[0]), [], []
for i in range(len(pca.components_)):
    # Reconstruct image using the first i principal components
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))    

    # Append animation frame every other reconstruction (and for the last PC)
    if i % 2 == 0 or i == pca.n_components_-1:
        frames.append(go.Frame(
            data = [px.imshow(img, binary_string=True).data[0],
                    px.imshow((img - reconstructed).copy(), binary_string=True).data[0],
                    px.imshow(reconstructed.copy(), binary_string=True).data[0],
                    go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion)],
            traces = [0, 1, 2, 3],
            layout = go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))


fig = make_subplots(rows=2, cols=3, 
                    subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
                    specs=[[{}, {}, {}], [{"colspan": 3}, None, None]], row_heights=[500, 200],)
fig.add_traces(data=frames[0]["data"], rows = [1,1,1,2], cols = [1,2,3,1])
fig.update(frames=frames)

fig.update_layout(title=frames[0]["layout"]["title"],
                  xaxis4=dict(range=[0, 50], autorange=False),
                  yaxis4=dict(range=[0, max(distortion)+1], autorange=False),
                  margin = dict(t = 100),
                  width=800,
                  updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])
fig.show()

I tried finding similar questions but could not find anything that works for showing both px.imshow and go.Scatter in subplots with animation.

The data X are the MNIST digit images after centering. Here is a numpy array with such images (X.shape=(16,5,5): 16 images of 5x5; the animation uses only a single image):

X=np.array( [[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]]] )

Placed the above code in a Jupyter notebook on GitHub

Upvotes: 7

Views: 6331

Answers (4)

minerharry

Reputation: 111

TL;DR: px.imshow automatically sets the x and y axes to be the first set of axes, since it returns a whole figure; you either need to use go.Image (last codeblock) or remove the xaxis and yaxis attributes of the data.

One of the neat things about plotly is that every object is a dict at heart, which you can see when you print the object to the terminal. When you do px.imshow, the resulting figure has the following info:

In [1]: imfig = px.imshow(image, binary_string=True)

In [2]: print(imfig)
Figure({
    'data': [{'hovertemplate': 'x: %{x}<br>y: %{y}<extra></extra>',
              'name': '0',
              'source': ('data:image/png;base64,iVBORw0K' ... 'cPdQY3ZVUt1YkAAAAASUVORK5CYII='),
              'type': 'image',
              'xaxis': 'x',
              'yaxis': 'y'}],
    'layout': {'template': '...',
               'xaxis': {'anchor': 'y', 'domain': [0.0, 1.0]},
               'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0]}}
})

Aha! In imfig.data[0], the image has xaxis and yaxis specified! This will always put the image in the first subplot. (By default the subplots use the axis pairs ('x','y'), ('x2','y2'), ('x3','y3'), etc.; plotly.py also accepts 'x1' and 'y1' as names for the first pair.) Indeed, in the frames list, every image has 'xaxis': 'x' and 'yaxis': 'y', meaning that each frame will place the images in the first subplot.

Knowing that, the fix is easy: just set the x and y axis to the proper values! The new loop would look like this:

for i in range(len(pca.components_)):
    # Reconstruct image using the first i principal components
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))   

    # Append animation frame every other reconstruction (and for the last PC)
    if i % 2 == 0 or i == pca.n_components_-1:
        imagedata = [px.imshow(img, binary_string=True).data[0],
                     px.imshow((img - reconstructed).copy(), binary_string=True).data[0],
                     px.imshow(reconstructed.copy(), binary_string=True).data[0]]
        for image,axis in zip(imagedata,['1','2','3']):
            image.xaxis = 'x' + axis #set the axes so that each image is 
            image.yaxis = 'y' + axis # displayed in the correct subplot
        frames.append(go.Frame(
            data = imagedata + 
                    [go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion)],
            traces = [0, 1, 2, 3],
            layout = go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))
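
Since every trace is a dict at heart, the axis reassignment can also be sketched on plain dicts, without plotly installed. This is only an illustration: the helper names axis_id and assign_subplot_axes are hypothetical, and the trace dicts are hand-written stand-ins for what px.imshow returns.

```python
def axis_id(base, index):
    """Return the axis id of the index-th subplot (1-based): 'x', 'x2', 'x3', ..."""
    if index < 1:
        raise ValueError("subplot index must be >= 1")
    return base if index == 1 else f"{base}{index}"


def assign_subplot_axes(traces):
    """Return copies of the trace dicts with the k-th trace on the k-th axis pair."""
    placed = []
    for k, trace in enumerate(traces, start=1):
        updated = dict(trace)              # leave the input dicts untouched
        updated["xaxis"] = axis_id("x", k)
        updated["yaxis"] = axis_id("y", k)
        placed.append(updated)
    return placed


# Stand-ins for px.imshow(...).data[0]: each one claims the first pair of axes.
raw = [{"type": "image", "source": f"img{k}", "xaxis": "x", "yaxis": "y"}
       for k in range(3)]
placed = assign_subplot_axes(raw)
print([(t["xaxis"], t["yaxis"]) for t in placed])
# [('x', 'y'), ('x2', 'y2'), ('x3', 'y3')]
```

The first subplot keeps the unsuffixed names, which is why leaving the px.imshow defaults in place sends every image there.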

In fact, there's an even easier way. When plotly applies a frame to the figure, it matches each element of the frame's data with the corresponding trace in the current figure and updates that trace accordingly. This means that if you don't specify a property in a frame's data dict, the value already on the figure's trace is kept. In the case of xaxis and yaxis, those values are set by add_traces (since you specify the rows and columns, the xaxis and yaxis properties of the frames[0].data objects are overwritten). So you can achieve the same effect without having to specify the location, just by removing xaxis and yaxis altogether:

for i in range(len(pca.components_)):
    # Reconstruct image using the first i principal components
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))   

    # Append animation frame every other reconstruction (and for the last PC)
    if i % 2 == 0 or i == pca.n_components_-1:
        imagedata = [px.imshow(img, binary_string=True).data[0],
                     px.imshow((img - reconstructed).copy(), binary_string=True).data[0],
                     px.imshow(reconstructed.copy(), binary_string=True).data[0]]
        for image in imagedata:
            image.xaxis = None #fields with value None are treated as 
            image.yaxis = None # deleted when converted to dict/json
        frames.append(go.Frame(
            data = imagedata + 
                    [go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion)],
            traces = [0, 1, 2, 3],
            layout = go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))
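
To make the merging behaviour described above concrete, here is a deliberately simplified model of it on plain dicts. It is not plotly's actual implementation; apply_frame is a hypothetical helper that only mimics the idea that keys missing from a frame trace leave the figure's values untouched.

```python
def apply_frame(figure_traces, frame_traces, trace_indices):
    """Merge each frame trace into the figure trace it targets.

    Keys that the frame trace does not mention keep the figure's value.
    """
    updated = [dict(t) for t in figure_traces]   # work on copies
    for frame_trace, idx in zip(frame_traces, trace_indices):
        updated[idx].update(frame_trace)
    return updated


# Figure traces as left by add_traces(rows=..., cols=...): one axis pair each.
figure = [{"source": "a0", "xaxis": "x", "yaxis": "y"},
          {"source": "b0", "xaxis": "x2", "yaxis": "y2"}]

# A frame that still carries px.imshow's defaults drags trace 1 onto 'x'/'y'.
bad = apply_frame(figure, [{"source": "b1", "xaxis": "x", "yaxis": "y"}], [1])
print(bad[1]["xaxis"])    # x

# A frame without xaxis/yaxis only swaps the image and stays on 'x2'/'y2'.
good = apply_frame(figure, [{"source": "b1"}], [1])
print(good[1]["xaxis"])   # x2
```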

Finally, if you want to go all the way into the weeds and skip the px.imshow altogether, you can specify the image data yourself, and change only the source. I think the minimum data you need to show the image is 'type':'image' and 'source':[source] - or, in other words, just go.Image(source=source); from the images tutorial:

The source attribute of a go.layout.Image can be the URL of an image, or a PIL Image object (from PIL import Image; img = Image.open('filename.png'))

So, the simplest solution is:

from PIL import Image
img, loadings = X[1], pca.transform(X[1].reshape(-1, 1)).T

reconstructed, distortion, frames = np.zeros_like(X[0]), [], []
for i in range(len(pca.components_)):
    # Reconstruct image using the first i principal components
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))    

    # Append animation frame every other reconstruction (and for the last PC)
    if i % 2 == 0 or i == pca.n_components_-1:
        frames.append(go.Frame(
            data = [go.Image(Image.fromarray(img)),
                    go.Image(Image.fromarray(img - reconstructed)), 
                    go.Image(Image.fromarray(reconstructed)),
                    go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion)],
            traces = [0, 1, 2, 3],
            layout = go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))


fig = make_subplots(rows=2, cols=3, 
                    subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
                    specs=[[{}, {}, {}], [{"colspan": 3}, None, None]], row_heights=[500, 200],)
fig.add_traces(data=frames[0]["data"], rows = [1,1,1,2], cols = [1,2,3,1])
fig.update(frames=frames)

fig.update_layout(title=frames[0]["layout"]["title"],
                  xaxis4=dict(range=[0, 50], autorange=False),
                  yaxis4=dict(range=[0, max(distortion)+1], autorange=False),
                  margin = dict(t = 100),
                  width=800,
                  updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])
fig.show()

--EDIT--

For some reason, it seems plotly has either removed or broken PIL image interoperability for go.Image; I can't get any of the examples in the documentation or online to work. (Note that the passage quoted above is about go.layout.Image, which may not handle source the same way as the go.Image trace.) The px.imshow way works, though, so I think the most reliable approach is to take only the source from the px.imshow trace:

img_data = px.imshow(image,binary_string=True).data[0].source
image = go.Image(source=img_data)
frame.data = [image]
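
For reference, the source printed by px.imshow earlier in this answer is a "data:image/png;base64,..." string. Below is a minimal, standard-library-only sketch of building such a string by hand for an 8-bit grayscale image. The function names are hypothetical, the pixel values must already be scaled to 0..255 (px.imshow does that rescaling for you), and whether go.Image(source=...) accepts the result is an assumption based on that printout, not something tested here.

```python
import base64
import struct
import zlib


def _chunk(tag, payload):
    """Build one PNG chunk: length, tag, payload, CRC over tag + payload."""
    crc = zlib.crc32(tag + payload) & 0xFFFFFFFF
    return struct.pack(">I", len(payload)) + tag + payload + struct.pack(">I", crc)


def grayscale_png_data_uri(rows):
    """Encode a 2-D list of ints in 0..255 as a grayscale PNG data URI."""
    height, width = len(rows), len(rows[0])
    # IHDR: width, height, bit depth 8, colour type 0 (grayscale),
    # default compression, default filter method, no interlacing
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 0, 0, 0, 0)
    # Every scanline starts with filter type 0 (no filtering)
    raw = b"".join(b"\x00" + bytes(row) for row in rows)
    png = (b"\x89PNG\r\n\x1a\n"
           + _chunk(b"IHDR", ihdr)
           + _chunk(b"IDAT", zlib.compress(raw))
           + _chunk(b"IEND", b""))
    return "data:image/png;base64," + base64.b64encode(png).decode("ascii")


uri = grayscale_png_data_uri([[0, 128, 255], [255, 128, 0]])
print(uri[:30])   # data:image/png;base64,iVBORw0K
```

The "iVBORw0K" prefix is simply the base64 encoding of the PNG signature, which is why it also shows up in the px.imshow output.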

Upvotes: 0

Gilad Green

Reputation: 37299

Similar to what jayveesea suggested, I ended up playing with the structure of the px.imshow figure. I first created the px.imshow figure with both facets and animation, and then added both the scatter plot and the desired layout to it:

pca = PCA(n_components=50).fit(X.reshape((X.shape[0], -1)))
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))

img, loadings = X[150], pca.transform(X[150].reshape(-1, 1)).T

reconstructed, distortion, images, scatters, titles = np.zeros_like(X[0]), [], [], [], []
for i in range(len(pca.components_)):
    # Reconstruct image using the first i principal components
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))    

    # Append animation frame every other reconstruction
    if i % 2 == 0 or i == pca.n_components_-1:
        images.append([img.copy(), reconstructed.copy(), (img - reconstructed).copy()])
        scatters.append(go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion, name=3, xaxis="x4", yaxis="y4", marker_color="black"))
        titles.append(rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")


        
# Create figure on the basis of the animated facetted imshow figure
fig = px.imshow(np.array(images), facet_col=1, animation_frame=0, binary_string=True)
for i, (scatter, title) in enumerate(zip(*[scatters, titles])):
    fig["frames"][i]["data"] += (scatter, )
    fig["frames"][i]["traces"] = [0,1,2,3]
    fig["frames"][i]["layout"]["title"] = title 
fig.add_traces(data=fig["frames"][0]["data"][-1])

# Create "template" figure to transfer layout onto the `fig` figure
layout = make_subplots(rows=2, cols=3, 
                       subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
                       specs=[[{"type":"Image"}, {"type":"Image"}, {"type":"Image"}], [{"type":"xy","colspan": 3}, None, None]], row_heights=[500, 200],)

layout.update_layout(title=titles[0],
                     xaxis4=dict(range=[0, 50], autorange=False),
                     yaxis4=dict(range=[0, max(distortion)+1], autorange=False),
                     margin = dict(t = 100), width=800,
                     updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])

fig["layout"] = layout["layout"]
fig

It is not a very elegant solution but it is a sufficient workaround.
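
The key to this workaround is the shape of the array handed to px.imshow: axis 0 indexes the animation frames (animation_frame=0) and axis 1 indexes the three panels (facet_col=1). A minimal sketch of that layout follows; the update rule is a toy stand-in rather than PCA, and the image is random data, both chosen only to keep the example self-contained.

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((5, 5))            # stand-in for one centered image
n_frames = 4

frames = []
reconstructed = np.zeros_like(img)
for _ in range(n_frames):
    # Toy update (halve the remaining error); the real code adds one PC here.
    reconstructed = reconstructed + (img - reconstructed) / 2
    frames.append([img.copy(), reconstructed.copy(), img - reconstructed])

stack = np.array(frames)
print(stack.shape)   # (4, 3, 5, 5) -> (frames, panels, height, width)
```

With this layout, stack[k, 1] + stack[k, 2] reproduces stack[k, 0] for every frame k, which is a quick sanity check that the "reconstructed" and "remaining" panels are consistent.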


Upvotes: 3

jayveesea

Reputation: 3199

While this is not a complete solution it may help get there...

Using animation_frame and facet_col you can build the upper part of the figure using facets. Unfortunately, I'm not sure how to link this to an animated scatter plot. You could render the scatter plots as images and tie them into this, but then you lose the ability to hover over the scatter and get info.

But, this may be of some value if you inspect the output print(fig0), and compare it to yours print(fig).

# X = see above

pca = PCA(n_components=15).fit(X.reshape((X.shape[0], -1)))
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))

img, loadings = X[1], pca.transform(X[1].reshape(-1, 1)).T
ORIG, DIFF, RECN, = [],[],[]

reconstructed, distortion, frames = np.zeros_like(X[0]), [], []
for i in range(len(pca.components_)):
    # Reconstruct image using the first i principal components
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))    

    # Append animation frame every other reconstruction (and for the last PC)
    if i % 2 == 0 or i == pca.n_components_-1:
      ORIG = np.append(ORIG,img)
      DIFF = np.append(DIFF,(img - reconstructed).copy())
      RECN = np.append(RECN,reconstructed.copy())

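# Reshape each flat buffer to (n_frames, height, width); -1 infers n_frames,
# so this also works for image sizes other than 16x16
DATA = np.array([np.reshape(ORIG, (-1, *img.shape)),
                 np.reshape(DIFF, (-1, *img.shape)),
                 np.reshape(RECN, (-1, *img.shape))])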

fig0 = px.imshow(DATA, animation_frame=1, facet_col=0, binary_string=True)
fig0.show()

print(fig0) #inspect the layout


Upvotes: 2

vestland

Reputation: 61094

Answer in progress...


This is what your github code produces on my end:

Initially...

enter image description here

And after animation ends:

enter image description here

Upvotes: 1
