MaiTruongSon

torch.nn.functional.interpolate(): Parameter Settings

I'm using torch.nn.functional.interpolate() to resize an image.

First I use transforms.ToTensor() to convert an image into a tensor of size (3, 252, 252), where (252, 252) is the size of the loaded image. What I want is a tensor of size (3, 504, 504), created with the interpolate() function.

I set the parameter scale_factor=2, but it returned a (3, 252, 504) tensor. Then I tried scale_factor=(1, 2, 2) and got a dimension-mismatch error like this:

size shape must match input shape. Input is 1D, size is 3

How should I set the parameters to get a (3, 504, 504) tensor?
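To see where these shapes come from, here is a simplified, pure-Python model of how interpolate() derives its output shape when scale_factor is given. `interpolate_output_shape` is a hypothetical helper, not part of PyTorch; it assumes the documented rule that the first two dimensions are (batch, channels) and every remaining spatial dimension becomes floor(size * scale). The real error text differs between PyTorch versions.

```python
import math
from typing import Sequence, Tuple, Union


def interpolate_output_shape(
    shape: Sequence[int], scale_factor: Union[float, Sequence[float]]
) -> Tuple[int, ...]:
    """Hypothetical model (not a PyTorch API) of interpolate's shape rule.

    The first two dims are read as (batch N, channels C) and kept as-is;
    each remaining spatial dim becomes floor(size * scale).
    """
    if len(shape) < 3:
        raise ValueError("input needs at least (N, C, W)")
    batch_and_channels, spatial = tuple(shape[:2]), tuple(shape[2:])
    if isinstance(scale_factor, (int, float)):
        # A scalar scale is applied to every spatial dim
        scales = (float(scale_factor),) * len(spatial)
    else:
        scales = tuple(float(s) for s in scale_factor)
        if len(scales) != len(spatial):
            raise ValueError(
                f"Input is {len(spatial)}D, scale_factor has {len(scales)} entries"
            )
    return batch_and_channels + tuple(
        math.floor(size * scale) for size, scale in zip(spatial, scales)
    )


# (3, 252, 252) is read as N=3, C=252 and a single spatial dim W=252
print(interpolate_output_shape((3, 252, 252), 2))     # (3, 252, 504)
# After unsqueeze(0): N=1, C=3, H=252, W=252
print(interpolate_output_shape((1, 3, 252, 252), 2))  # (1, 3, 504, 504)
# Three scale entries for one spatial dim are rejected
try:
    interpolate_output_shape((3, 252, 252), (1, 2, 2))
except ValueError as err:
    print(err)  # Input is 1D, scale_factor has 3 entries
```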


Answers (1)

yakhyo

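interpolate() expects a batch of images, not a single image: its input is interpreted as mini-batch x channels x [optional depth] x [optional height] x width. Your (3, 252, 252) tensor is therefore read as N=3, C=252 with a single spatial dimension (width 252), which is why scale_factor=2 only doubled the last dimension and why a 3-element scale_factor was rejected ("Input is 1D"). The same applies when you pass size instead of scale_factor. Add a batch dimension with unsqueeze(0), then pass the result to interpolate as follows: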

import torch
import torch.nn.functional as F

img = torch.randn(3, 252, 252)  # torch.Size([3, 252, 252])
img = img.unsqueeze(0)  # torch.Size([1, 3, 252, 252])

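out = F.interpolate(img, scale_factor=(2, 2), mode='nearest')
print(out.size())  # torch.Size([1, 3, 504, 504])

out = out.squeeze(0)  # remove the batch dimension again
print(out.size())  # torch.Size([3, 504, 504])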
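If you resize single images often, the unsqueeze/interpolate steps can be wrapped in a small helper. `resize_chw` is a hypothetical name, not a PyTorch API; this sketch passes an explicit target size instead of scale_factor, returns a (C, H, W) tensor, and defaults to bilinear mode (an assumption; pass mode='nearest' to match the answer's code).

```python
import torch
import torch.nn.functional as F


def resize_chw(img: torch.Tensor, size, mode: str = "bilinear") -> torch.Tensor:
    """Hypothetical helper: resize a single (C, H, W) image to (C, *size)."""
    if img.dim() != 3:
        raise ValueError(f"expected a (C, H, W) tensor, got shape {tuple(img.shape)}")
    batched = img.unsqueeze(0)  # (1, C, H, W), the layout interpolate expects
    # align_corners may only be set for the linear-family modes
    linear_modes = ("linear", "bilinear", "bicubic", "trilinear")
    kwargs = {"align_corners": False} if mode in linear_modes else {}
    out = F.interpolate(batched, size=size, mode=mode, **kwargs)
    return out.squeeze(0)  # back to (C, H', W')


img = torch.rand(3, 252, 252)  # stand-in for the output of transforms.ToTensor()
print(resize_chw(img, size=(504, 504)).shape)  # torch.Size([3, 504, 504])
```

Passing size=(504, 504) pins the output shape exactly, whereas scale_factor derives it as floor(input_size * scale) for each spatial dimension.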
