Reputation: 952
I have a 3D matrix: img[i, j, k] = i+j+k.
In my opinion, if I want the value at (1, 2, 3), grid_sample should return 6. But it does not.
The code is:
import torch
from torch.nn import functional as F
import numpy as np
X, Y, Z = 10, 20, 30
img = np.zeros(shape=[X, Y, Z], dtype=np.float32)
for i in range(X):
    for j in range(Y):
        for k in range(Z):
            img[i,j,k] = i+j+k
inp = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)
grid = torch.from_numpy(np.array([[1, 2, 3]], dtype=np.float32)).unsqueeze(1).unsqueeze(1).unsqueeze(1)
grid[..., 0] /= (X-1)
grid[..., 1] /= (Y-1)
grid[..., 2] /= (Z-1)
grid = 2*grid - 1
outp = F.grid_sample(inp, grid=grid, mode='bilinear', align_corners=True)
print(outp)
grid_sample returns 6.15 instead. Is there anything wrong with my code?
Upvotes: 1
Views: 1246
Reputation: 41
I ran into this issue recently, so I decided to answer the question here.
If you read the PyTorch documentation for grid_sample(), you will find that it expects grid coordinates in x, y, z order, not z, y, x:
In the case of 5D inputs, grid[n, d, h, w] specifies the x, y, z pixel locations for interpolating output[n, :, d, h, w]. mode argument specifies nearest or bilinear interpolation method to sample the input pixels.
But you defined your image in an unusual way: img = np.zeros(shape=[X, Y, Z], dtype=np.float32). In your example, the image is a numpy array with dimensions (DimX, DimY, DimZ).
This is unusual! A 3D image is usually stored as a numpy array with dimensions (DimZ, DimY, DimX): Z is the depth, Y is the height, and X is the width.
The grid_sample documentation says the input for 3D images (a 5D tensor) should have shape (N, C, D, H, W), which matches the (DimZ, DimY, DimX) layout I described. What makes it confusing is that the grid parameter has shape (N, D_out, H_out, W_out, 3), and the sampling coordinates stored at (n, d, h, w, :) are in (x, y, z) order, the reverse of the tensor's axis order.
This is very annoying indeed. And you happened to make two mistakes that cancelled each other out :)
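To make the axis order concrete, here is a minimal sketch (my own illustration, not taken from the documentation). The names vol and to_norm are hypothetical, and the sizes are arbitrary. Each voxel stores 100*d + 10*h + w, so the sampled value shows which index ended up on which axis.

```python
import torch
from torch.nn import functional as F

# Hypothetical small volume in the conventional (D, H, W) layout.
D, H, W = 4, 5, 6
d = torch.arange(D, dtype=torch.float32).view(D, 1, 1)
h = torch.arange(H, dtype=torch.float32).view(1, H, 1)
w = torch.arange(W, dtype=torch.float32).view(1, 1, W)
vol = 100 * d + 10 * h + w      # vol[d, h, w] encodes its own index
inp = vol[None, None]           # shape (1, 1, D, H, W)


def to_norm(index, size):
    """Map a voxel index to [-1, 1], assuming align_corners=True."""
    return 2.0 * index / (size - 1) - 1.0


# We want vol[1, 2, 3], i.e. d=1, h=2, w=3.
# The grid stores (x, y, z), where x runs along W, y along H, z along D.
point = torch.tensor([to_norm(3, W), to_norm(2, H), to_norm(1, D)])
grid = point.view(1, 1, 1, 1, 3)

value = F.grid_sample(inp, grid, mode='bilinear', align_corners=True).item()
print(value)  # should be close to 123, i.e. vol[1, 2, 3]
```

Note that each coordinate is normalized by the size of the axis it actually indexes (x by W, y by H, z by D).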
https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
Upvotes: 2
Reputation: 952
I finally found the solution. The code above returns an incorrect value because F.grid_sample expects each point in (z, y, x) order relative to my [X, Y, Z] array axes: the first grid coordinate indexes the last tensor dimension.
Thus, the correct code should be:
import torch
from torch.nn import functional as F
import numpy as np
X, Y, Z = 10, 20, 30
img = np.zeros(shape=[X, Y, Z], dtype=np.float32)
for i in range(X):
    for j in range(Y):
        for k in range(Z):
            img[i,j,k] = i+j+k
inp = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)
grid = torch.from_numpy(np.array([[1, 2, 3]], dtype=np.float32)).unsqueeze(1).unsqueeze(1).unsqueeze(1)
grid[..., 0] /= (X-1)
grid[..., 1] /= (Y-1)
grid[..., 2] /= (Z-1)
grid = 2*grid - 1
newgrid = grid.clone()
newgrid[..., 0] = grid[..., 2]
newgrid[..., 1] = grid[..., 1]
newgrid[..., 2] = grid[..., 0]
outp = F.grid_sample(inp, grid=newgrid, mode='bilinear', align_corners=True)
print(outp)
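The same fix can be written more compactly. This is only a sketch: sample_at_index is a hypothetical helper name, and it assumes align_corners=True and a single query point. It normalizes each index by the size of its own array axis, then reverses the last dimension so the grid is in the order grid_sample expects.

```python
import torch
from torch.nn import functional as F

X, Y, Z = 10, 20, 30
# Build img[i, j, k] = i + j + k with broadcasting instead of three loops.
i = torch.arange(X, dtype=torch.float32).view(X, 1, 1)
j = torch.arange(Y, dtype=torch.float32).view(1, Y, 1)
k = torch.arange(Z, dtype=torch.float32).view(1, 1, Z)
img = i + j + k


def sample_at_index(volume, index):
    """Sample a 3D tensor at a (possibly fractional) index given in array-axis order."""
    sizes = torch.tensor(list(volume.shape), dtype=torch.float32)
    idx = torch.tensor(index, dtype=torch.float32)
    # Normalize every axis to [-1, 1] (align_corners=True convention).
    norm = 2.0 * idx / (sizes - 1) - 1.0
    # grid_sample wants the last tensor axis first, so reverse the order.
    grid = norm.flip(-1).view(1, 1, 1, 1, 3)
    out = F.grid_sample(volume[None, None], grid, mode='bilinear', align_corners=True)
    return out.item()


result = sample_at_index(img, (1, 2, 3))
print(result)  # should be close to 6
```

Because img is linear in each index, a fractional query such as sample_at_index(img, (1.5, 2, 3)) should come out close to 6.5.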
Upvotes: 2