Mohit Lamba

Reputation: 1403

What information does PyTorch nn.functional.interpolate use?

I have a PyTorch tensor img of size b x 2 x h x w and want to upsample it with torch.nn.functional.interpolate. During interpolation, I do not want channel 1 to use information from channel 2. To do this, should I upsample each channel separately,

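import torch
from torch import nn

img2 = torch.empty(b, 2, 2*h, 2*w)  # output buffer; every entry is overwritten below
# Slice with 0:1 and 1:2 to keep the channel dimension: bilinear mode needs a 4D input.
img2[:, 0:1] = nn.functional.interpolate(img[:, 0:1], [2*h, 2*w], mode='bilinear', align_corners=True)
img2[:, 1:2] = nn.functional.interpolate(img[:, 1:2], [2*h, 2*w], mode='bilinear', align_corners=True)
img = img2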

or is a single call on the whole tensor enough?

img = nn.functional.interpolate(img, [2*h,2*w], mode='bilinear', align_corners=True)

Upvotes: 7

Views: 13464

Answers (1)

hkchengrex

Reputation: 4826

Use the second option, the single call on the whole tensor. F.interpolate resamples only the spatial dimensions. The first and second dimensions (batch and channel, respectively) are never mixed, and this holds for every kind of interpolation (1D, 2D, 3D), as it should.

Simple example:

import torch
import torch.nn.functional as F

b = 2
c = 4
h = w = 8

a = torch.randn((b, c, h, w))
a_upsample = F.interpolate(a, [h*2, w*2], mode='bilinear', align_corners=True)

a_mod = a.clone()
a_mod[:, 0] *= 1000
a_mod_upsample = F.interpolate(a_mod, [h*2, w*2], mode='bilinear', align_corners=True)

print(torch.isclose(a_upsample[:,0], a_mod_upsample[:,0]).all())
print(torch.isclose(a_upsample[:,1], a_mod_upsample[:,1]).all())
print(torch.isclose(a_upsample[:,2], a_mod_upsample[:,2]).all())
print(torch.isclose(a_upsample[:,3], a_mod_upsample[:,3]).all())

Output:

tensor(False)
tensor(True)
tensor(True)
tensor(True)

Only channel 0 differs, so a large change in the first channel has no effect on the other channels.
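A more direct check of the original question is to compare the two approaches side by side. The sketch below is not part of the original answer. It assumes a small random tensor with two channels, and the names whole and per_channel are only illustrative. Each channel is sliced with i:i+1 so the input stays 4D, which bilinear mode expects.

```python
import torch
import torch.nn.functional as F

b, c, h, w = 2, 2, 8, 8
img = torch.randn(b, c, h, w)

# Single call on the whole tensor.
whole = F.interpolate(img, size=(2 * h, 2 * w), mode='bilinear', align_corners=True)

# One call per channel. Slicing with i:i+1 keeps the channel dimension,
# so each input is still a 4D (N, 1, H, W) tensor.
per_channel = torch.cat(
    [
        F.interpolate(img[:, i:i + 1], size=(2 * h, 2 * w), mode='bilinear', align_corners=True)
        for i in range(c)
    ],
    dim=1,
)

print(per_channel.shape)
print(torch.allclose(whole, per_channel))
```

If the channels are independent, the two results should agree, in which case the per-channel loop adds nothing beyond extra calls.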

Upvotes: 6
