Rylan Schaeffer

Reputation: 2705

PyTorch - How to use Avg 2d Pooling as a dataset transform?

In PyTorch, I have a dataset of 2D images (or alternatively, 1-channel images), and I'd like to apply 2D average pooling as a transform. How do I do this? The following does not work:

    omniglot_dataset = torchvision.datasets.Omniglot(
        root=data_dir,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.CenterCrop((80, 80)),
            # torchvision.transforms.Resize((10, 10))
            torch.nn.functional.avg_pool2d(kernel_size=3, strides=1),
        ])
    )

Answers (2)

Rylan Schaeffer

Reputation: 2705

yutasrobot's answer below is perfectly satisfactory. Another answer I received on the PyTorch forum can be found at https://discuss.pytorch.org/t/how-to-use-avg-2d-pooling-as-a-dataset-transform/117995/2 and is quoted here:

"""

You can use transforms.Lambda to call the functional API:

    import torch
    import torchvision

    transform = torchvision.transforms.Compose([
        torchvision.transforms.CenterCrop((80, 80)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(
            lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=1)
        ),
    ])

    # Example input: a random 3-channel 224x224 PIL image
    img = torchvision.transforms.ToPILImage()(torch.rand(3, 224, 224))
    out = transform(img)
    print(out.size())  # torch.Size([3, 78, 78])

"""

yutasrobot

Reputation: 2496

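Transforms have to be callable objects. torch.nn.functional.avg_pool2d is a plain function that expects the input and the parameters together in a single call, as do all the functions under torch.nn.functional. Writing torch.nn.functional.avg_pool2d(kernel_size=3, strides=1) inside Compose therefore calls the function right away, without any input (and with strides instead of the correct keyword stride), so it raises an error instead of producing a transform. You need to use the module version instead: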

    torch.nn.AvgPool2d(kernel_size=3, stride=1)

This creates a module object, which is callable and can be applied to a given input later, for example:

    pooler = torch.nn.AvgPool2d(kernel_size=3, stride=1)
    x = torch.rand(1, 80, 80)   # example input: one 1-channel 80x80 image
    output = pooler(x)          # torch.Size([1, 78, 78])
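
To make the failure concrete, here is a minimal sketch of what the question's Compose entry actually evaluates, next to the module version. The random 80x80 input is only an assumption for illustration.

```python
import torch
import torch.nn.functional as F

# What the question's Compose list effectively does: it *calls* avg_pool2d
# immediately, with no input tensor (and the misspelled keyword `strides`).
try:
    F.avg_pool2d(kernel_size=3, strides=1)
    failed = False
except TypeError as err:
    failed = True
    print("TypeError:", err)

# The module stores the parameters now and receives the input later.
pooler = torch.nn.AvgPool2d(kernel_size=3, stride=1)
print(callable(pooler))  # True

x = torch.rand(1, 80, 80)   # one 1-channel 80x80 image
print(pooler(x).shape)      # torch.Size([1, 78, 78])
```

The exact TypeError message depends on the PyTorch version, but the call fails before Compose ever receives a transform.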

With this change, here are a few ways to use the callable version:

    import torchvision
    import torch
    import matplotlib.pyplot as plt

    # Version 1: crop only, no pooling
    omniglotv1 = torchvision.datasets.Omniglot(
        root='./dataset/',
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.CenterCrop((80, 80))
        ])
    )

    x1, y = omniglotv1[0]
    print(x1.size())      # torch.Size([1, 80, 80])

    # Version 2: nn.AvgPool2d module created inline in Compose
    omniglotv2 = torchvision.datasets.Omniglot(
        root='./dataset/',
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.CenterCrop((80, 80)),
            torch.nn.AvgPool2d(kernel_size=3, stride=1)
        ])
    )

    x2, y = omniglotv2[0]
    print(x2.size())      # torch.Size([1, 78, 78])

    # Version 3: pooling module created beforehand and passed in
    pooler = torch.nn.AvgPool2d(kernel_size=3, stride=1)
    omniglotv3 = torchvision.datasets.Omniglot(
        root='./dataset/',
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.CenterCrop((80, 80)),
            pooler
        ])
    )

    x3, y = omniglotv3[0]
    print(x3.size())      # torch.Size([1, 78, 78])

Here is a short snippet that plots the original and the pooled image to show what the transform does:

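    import numpy as np
    import matplotlib.pyplot as plt

    x_img   = x1.squeeze().cpu().numpy()   # original crop, 80x80
    ave_img = x2.squeeze().cpu().numpy()   # pooled, 78x78
    # Stack both images vertically on one canvas (80 + 78 = 158 rows)
    combined = np.zeros((158, 80))
    combined[0:80, 0:80] = x_img
    combined[80:, 0:78] = ave_img
    plt.imshow(combined)
    plt.show()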
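
What the plot shows can also be checked numerically: with kernel_size=3, stride=1 and no padding, each output pixel is the mean of a 3x3 window of the input, which is why the pooled image looks slightly blurred and is 2 pixels smaller in each dimension (80 - 3 + 1 = 78). A minimal sketch, assuming a random 80x80 tensor in place of an Omniglot crop, rebuilds those windows with Tensor.unfold and compares:

```python
import torch

torch.manual_seed(0)
x = torch.rand(1, 80, 80)  # stand-in for one cropped 1-channel image
pooled = torch.nn.AvgPool2d(kernel_size=3, stride=1)(x)

# Build every 3x3 window explicitly: unfolding dim 1, then dim 2, with size 3
# and step 1 gives shape [1, 78, 78, 3, 3]; averaging the last two dims
# reproduces the average pooling.
windows = x.unfold(1, 3, 1).unfold(2, 3, 1)
manual = windows.mean(dim=(-2, -1))

print(pooled.shape)                    # torch.Size([1, 78, 78])
print(torch.allclose(pooled, manual))  # True
```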
