Zabir Al Nazi Nabil

Reputation: 11228

PyTorch: applying transforms to a 4D NumPy array inside a custom dataset

Inside my custom dataset, I want to apply transforms.Compose() to a NumPy array.

My images are in a NumPy array format with shape (num_samples, width, height, channels).

How can I apply the following transforms to the full NumPy array?

img_transform = transforms.Compose([
        transforms.Scale((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
    ])

My attempts end in multiple errors because these transforms expect a PIL image, not a 4D NumPy array.

from torchvision import transforms
import numpy as np
import torch

img_transform = transforms.Compose([
        transforms.Scale((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
    ])

a = np.random.randint(0,256, (299,299,3))
print(a.shape)

img_transform(a)

Upvotes: 1

Views: 1876

Answers (1)

Michael Jungo

Reputation: 33010

All torchvision transforms operate on single images, not on batches of images, so a 4D array cannot be passed to them directly.

Single images given as NumPy arrays, as in your code example, can be used by converting them to a PIL image first. Simply add transforms.ToPILImage to the beginning of the transformation pipeline; it converts either a tensor or a NumPy array to a PIL image.

img_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
    ])

Note: transforms.Scale is deprecated in favour of transforms.Resize.

In your example you used np.random.randint, which returns int64 values by default, but image arrays need to have dtype uint8. Libraries such as OpenCV return uint8 arrays when loading an image.

a = np.random.randint(0,256, (299,299,3), dtype=np.uint8)
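Since the transforms operate on one image at a time, the 4D array from your question can be handled by applying this pipeline per sample inside the custom dataset, typically in __getitem__. Here is a minimal sketch along those lines (the class name, the DataLoader usage and the array shapes are placeholder assumptions, not part of your original code; the pipeline is repeated for completeness):

import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

img_transform = transforms.Compose([
    transforms.ToPILImage(),      # NumPy HWC uint8 -> PIL image
    transforms.Resize((224, 224)),
    transforms.ToTensor(),        # PIL -> CHW float tensor in [0, 1]
    transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32]),
])

class NumpyImageDataset(Dataset):
    """Wraps a (num_samples, H, W, C) uint8 array and transforms one image per item."""
    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]        # a single H x W x C image
        if self.transform is not None:
            img = self.transform(img)
        return img

images = np.random.randint(0, 256, (10, 299, 299, 3), dtype=np.uint8)
dataset = NumpyImageDataset(images, transform=img_transform)
loader = DataLoader(dataset, batch_size=4)
print(next(iter(loader)).shape)       # torch.Size([4, 3, 224, 224])

The DataLoader then stacks the per-image tensors back into a 4D batch, so you end up with batches of shape (batch_size, 3, 224, 224) without ever passing the full array through the transforms.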

Upvotes: 1
