Reputation: 1403
I have a tensor img in PyTorch of size b x 2 x h x w and want to upsample it using torch.nn.functional.interpolate. During interpolation, I do not want channel 1 to use information from channel 2. To do this, should I do
import torch
import torch.nn as nn
img2 = torch.empty(b, 2, 2*h, 2*w)  # preallocate the output tensor
# img[:, 0:1] keeps the channel dimension, since bilinear mode expects a 4D (N, C, H, W) input
img2[:, 0:1] = nn.functional.interpolate(img[:, 0:1], [2*h, 2*w], mode='bilinear', align_corners=True)
img2[:, 1:2] = nn.functional.interpolate(img[:, 1:2], [2*h, 2*w], mode='bilinear', align_corners=True)
img = img2
or will simply using
img = nn.functional.interpolate(img, [2*h, 2*w], mode='bilinear', align_corners=True)
serve my purpose?
Upvotes: 7
Views: 13464
Reputation: 4826
You should use the second option. interpolate never mixes values across the first and second dimensions (batch and channel, respectively), as one would expect; this holds for every kind of interpolation (1D, 2D, 3D). Only the remaining spatial dimensions are resampled.
A simple example:
import torch
import torch.nn.functional as F
b = 2
c = 4
h = w = 8
a = torch.randn((b, c, h, w))
# Upsample the original tensor.
a_upsample = F.interpolate(a, [h*2, w*2], mode='bilinear', align_corners=True)
# Scale only channel 0 of a copy, then upsample the copy.
a_mod = a.clone()
a_mod[:, 0] *= 1000
a_mod_upsample = F.interpolate(a_mod, [h*2, w*2], mode='bilinear', align_corners=True)
# Compare the two results channel by channel.
print(torch.isclose(a_upsample[:,0], a_mod_upsample[:,0]).all())
print(torch.isclose(a_upsample[:,1], a_mod_upsample[:,1]).all())
print(torch.isclose(a_upsample[:,2], a_mod_upsample[:,2]).all())
print(torch.isclose(a_upsample[:,3], a_mod_upsample[:,3]).all())
Output:
tensor(False)
tensor(True)
tensor(True)
tensor(True)
Scaling the first channel by 1000 changes only that channel's upsampled output; the other channels are unaffected.
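The same point can be checked more directly by comparing a single call on the whole tensor against a channel-by-channel loop. The following is a minimal sketch, assuming small illustrative sizes (b = 3, h = 5, w = 7) and a tolerance chosen for the comparison; the names joint and per_channel are made up for this example. Each channel is sliced as img[:, c:c+1] so the input stays 4D, which bilinear mode requires.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, w = 3, 5, 7  # illustrative sizes, not from the question
img = torch.randn(b, 2, h, w)

# Single call on the whole tensor (the second option in the question).
joint = F.interpolate(img, size=[2 * h, 2 * w], mode='bilinear', align_corners=True)

# Channel-by-channel upsampling (the first option in the question).
# Slicing with c:c+1 keeps the channel dimension, so each input is still 4D.
per_channel = torch.cat(
    [
        F.interpolate(img[:, c:c + 1], size=[2 * h, 2 * w],
                      mode='bilinear', align_corners=True)
        for c in range(img.shape[1])
    ],
    dim=1,
)

# If channels are interpolated independently, both results agree.
same = torch.allclose(joint, per_channel, atol=1e-6)
print(same)
```

If the two results agree, the per-channel loop adds nothing beyond extra code and extra kernel launches, so the single call is the simpler choice.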
Upvotes: 6