Peter Minev

Reputation: 1

How to compute a Scattering2D transform of an image file?

I am trying to compute the wavelet scattering transform of a 2D image with the following code:

import torch
from kymatio import Scattering2D
import numpy as np
import PIL
from PIL import Image

FILENAME = "./square.png"
image = PIL.Image.open(FILENAME).convert("L")

a = np.array(image).astype(np.float64)
x = torch.from_numpy(a)
imageSize = x.shape

print( imageSize )

scattering = Scattering2D(J=2, shape=imageSize, frontend='numpy', L=8)

Sx = scattering(x)

print(Sx.size())

When I run it, I get the following error. Can anyone help?

Traceback (most recent call last):
  File "/Users/pminev/Desktop/scatter2d.py", line 18, in <module>
    Sx = scattering(x)
         ^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/kymatio/frontend/torch_frontend.py", line 22, in forward
    return self.scattering(x)
           ^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/kymatio/scattering2d/frontend/torch_frontend.py", line 98, in scattering
    S = scattering2d(input, self.pad, self.unpad, self.backend, self.J,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/kymatio/scattering2d/core/scattering2d.py", line 19, in scattering2d
    U_1_c = cdgmm(U_0_c, phi['levels'][0])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/kymatio/backend/torch_backend.py", line 192, in cdgmm
    raise TypeError('Input and filter must be of the same dtype.')
TypeError: Input and filter must be of the same dtype.

I am trying to compute the wavelet scattering transform of an image.

Upvotes: 0

Views: 95

Answers (1)

Sebastien Grand

Reputation: 61

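You could try using kymatio's NumPy frontend (`kymatio.numpy`) instead, as below. The problem seems to come from the torch backend: its `cdgmm` requires the input and the filters to have the same dtype, but your image is converted to `float64` while kymatio's filters are single precision (`float32`).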
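Whichever frontend you use, one simple way to avoid the dtype mismatch is to convert the image to a 2D `float32` array before building the transform, since single precision is what the torch backend's filters use. The sketch below uses a hypothetical helper, `prepare_image`, which is not part of kymatio; cropping both sides to a multiple of `2**J` (the transform's subsampling factor) is an extra assumption of this sketch, not something the original code does.

```python
import numpy as np


def prepare_image(a, J):
    """Return a 2-D float32 version of `a`, cropped so both sides are multiples of 2**J.

    Hypothetical helper (not part of kymatio). Casting to float32 is meant to
    match the single-precision filters of kymatio's torch backend; the crop is
    an optional convenience assumed by this sketch.
    """
    img = np.asarray(a, dtype=np.float32)  # float64 -> float32
    if img.ndim != 2:
        raise ValueError(f"expected a 2-D grayscale array, got shape {img.shape}")

    step = 2 ** J
    # Largest sizes <= the original that divide evenly by 2**J
    M = (img.shape[0] // step) * step
    N = (img.shape[1] // step) * step
    if M == 0 or N == 0:
        raise ValueError(f"image is smaller than 2**J = {step} pixels")
    return img[:M, :N]
```

With the torch frontend, something like `x = torch.from_numpy(prepare_image(a, J=2))` followed by `Scattering2D(J=2, shape=x.shape, L=8)` should then hand `float32` data to the backend; with `kymatio.numpy`, the returned array can be passed directly.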

import numpy as np
import matplotlib.pyplot as plt
import kymatio.numpy as kp
from PIL import Image

# Load the image as a 2D grayscale array
FILENAME = "./square.png"
pil_img = Image.open(FILENAME).convert("L")
img = np.array(pil_img).astype(float)

# Define wavelet scattering parameters
J = 2  # Number of scales
L = 8  # Number of orientations (angles)

# Compute the wavelet scattering transform with the NumPy frontend;
# passing shape=img.shape also works for non-square images
scattering = kp.Scattering2D(J=J, shape=img.shape, L=L)
scattering_coeffs = scattering(img)

# Plot the original image
plt.figure(figsize=(8, 4))
plt.subplot(1, 3, 1)
plt.imshow(img, cmap='gray')
plt.title('Original Image')
plt.axis('off')

# Plot the first-order scattering coefficients
plt.subplot(1, 3, 2)
plt.imshow(scattering_coeffs[1], cmap='viridis')
plt.title('First-order Scattering Coefficients')
plt.axis('off')

# Plot one second-order scattering channel: channel 0 is order 0,
# channels 1..J*L are order 1, and order 2 starts at index 1 + J*L
plt.subplot(1, 3, 3)
plt.imshow(scattering_coeffs[1 + J * L], cmap='viridis')
plt.title('Second-order Scattering Coefficients')
plt.axis('off')

plt.tight_layout()
plt.show()
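The plotting code indexes the first axis of the output directly, so it helps to know how the channels are laid out. Assuming kymatio's default 2D ordering (the single order-0 channel first, then the `J*L` order-1 channels, then the `L*L*J*(J-1)/2` order-2 channels), a small hypothetical helper, `scattering2d_channel_ranges` (not a kymatio function), gives the index range of each order:

```python
def scattering2d_channel_ranges(J, L):
    """Index ranges of the order-0, order-1 and order-2 channels.

    Hypothetical helper; assumes the default kymatio 2D layout with
    max_order=2: order 0, then J*L order-1 maps, then L*L*J*(J-1)/2 order-2 maps.
    """
    n0 = 1
    n1 = J * L
    n2 = L * L * J * (J - 1) // 2
    return {
        0: range(0, n0),
        1: range(n0, n0 + n1),
        2: range(n0 + n1, n0 + n1 + n2),
    }


ranges = scattering2d_channel_ranges(J=2, L=8)
for order, r in ranges.items():
    print(f"order {order}: indices {r.start} to {r.stop - 1}, {len(r)} channel(s)")
```

With `J = 2` and `L = 8` there are 81 channels: index 0 is order 0, indices 1 to 16 are order 1 and indices 17 to 80 are order 2, so `scattering_coeffs[2]` is still a first-order map. Averaging over one order, e.g. `scattering_coeffs[ranges[2].start:ranges[2].stop].mean(axis=0)`, is one way to get a single summary image per order.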

Upvotes: 0
