Qiang Zhang

Reputation: 952

PyTorch's grid_sample returns an incorrect value

I have a 3D matrix: img[i, j, k] = i+j+k.

I expect that if I sample the value at (1, 2, 3), grid_sample should return 6. But it does not.

The code is:

import torch
from torch.nn import functional as F
import numpy as np
X, Y, Z = 10, 20, 30
img = np.zeros(shape=[X, Y, Z], dtype=np.float32)
for i in range(X):
    for j in range(Y):
        for k in range(Z):
            img[i,j,k] = i+j+k
inp = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)
grid = torch.from_numpy(np.array([[1, 2, 3]], dtype=np.float32)).unsqueeze(1).unsqueeze(1).unsqueeze(1)
grid[..., 0] /= (X-1)
grid[..., 1] /= (Y-1)
grid[..., 2] /= (Z-1)
grid = 2*grid - 1
outp = F.grid_sample(inp, grid=grid, mode='bilinear', align_corners=True)
print(outp)

grid_sample returns 6.15. Is there anything wrong with my code?

Upvotes: 1

Views: 1246

Answers (2)

user1836485

Reputation: 41

I encountered this issue recently, so I decided to answer this question here.

If you read PyTorch's grid_sample() documentation, you will find that grid_sample actually accepts coordinates in x, y, z order, not z, y, x:

In the case of 5D inputs, grid[n, d, h, w] specifies the x, y, z pixel locations for interpolating output[n, :, d, h, w]. mode argument specifies nearest or bilinear interpolation method to sample the input pixels.
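To see the (x, y, z) order in action, here is a minimal sketch on a small made-up volume. The sizes D=2, H=3, W=4, the 100*d + 10*h + w encoding and the helper `to_norm` are illustrative assumptions, not part of PyTorch. Because each voxel stores its own index, the sampled value tells you which voxel was hit:

```python
import torch
import torch.nn.functional as F

# Small made-up volume with distinct sizes so each axis is identifiable
D, H, W = 2, 3, 4
d = torch.arange(D, dtype=torch.float32).view(D, 1, 1)
h = torch.arange(H, dtype=torch.float32).view(1, H, 1)
w = torch.arange(W, dtype=torch.float32).view(1, 1, W)
# Each voxel stores its own index as 100*d + 10*h + w; shape (N=1, C=1, D, H, W)
vol = (100 * d + 10 * h + w).view(1, 1, D, H, W)

def to_norm(idx, size):
    # Hypothetical helper: align_corners=True maps index 0 to -1 and index size-1 to +1
    return 2.0 * idx / (size - 1) - 1.0

# Target voxel (d=1, h=1, w=2); the grid stores it as (x, y, z) = (w, h, d)
grid = torch.tensor([to_norm(2, W), to_norm(1, H), to_norm(1, D)]).view(1, 1, 1, 1, 3)
out = F.grid_sample(vol, grid, mode='bilinear', align_corners=True)
print(out.item())  # close to 112.0, i.e. voxel (d=1, h=1, w=2)

# Feeding the same point in (z, y, x) order lands somewhere else
wrong = F.grid_sample(vol, grid.flip(-1), mode='bilinear', align_corners=True)
print(wrong.item())  # roughly 79.67, not 112
```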

However, you defined your image in an unusual way: img = np.zeros(shape=[X, Y, Z], dtype=np.float32) creates a NumPy array with dimensions (DimX, DimY, DimZ).

A 3D image is usually stored as a NumPy array with dimensions (DimZ, DimY, DimX): dimension Z is the depth, dimension Y is the height and dimension X is the width.
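A minimal sketch of that (Z, Y, X) layout, assuming you keep the question's sizes and the formula i + j + k but build the array with np.meshgrid instead of nested loops. With the array stored depth-first, the grid can be filled in plain (x, y, z) order and the point (1, 2, 3) samples 6:

```python
import numpy as np
import torch
import torch.nn.functional as F

X, Y, Z = 10, 20, 30
# Store the volume as [Z, Y, X] (depth, height, width) so it matches (D, H, W)
z, y, x = np.meshgrid(np.arange(Z), np.arange(Y), np.arange(X), indexing='ij')
img = (x + y + z).astype(np.float32)     # img[k, j, i] = i + j + k
inp = torch.from_numpy(img)[None, None]  # shape (1, 1, Z, Y, X)

# Voxel coordinates in (x, y, z) order, normalized for align_corners=True
point = torch.tensor([1.0, 2.0, 3.0])
sizes = torch.tensor([X, Y, Z], dtype=torch.float32)
grid = (2 * point / (sizes - 1) - 1).view(1, 1, 1, 1, 3)

outp = F.grid_sample(inp, grid, mode='bilinear', align_corners=True)
print(outp.item())  # approximately 6.0
```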

The grid_sample documentation says the input for 3D images (a 5D tensor) should have shape (N, C, D, H, W), which is why I define the image this way. The confusing part is that the grid parameter has shape (N, D_out, H_out, W_out, 3), and the sampling point stored at grid[n, d, h, w, :] is given in (x, y, z) order, where x indexes the W axis, y the H axis and z the D axis.
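The 6.15 in the question follows directly from this ordering. A short arithmetic sketch, assuming the align_corners=True mapping index = (g + 1) / 2 * (size - 1); the helper `unnormalize` is hypothetical and written out only to show the numbers:

```python
# Question's setup: img has shape [X, Y, Z] = [10, 20, 30], so grid_sample sees D=10, H=20, W=30
D, H, W = 10, 20, 30
# The question normalized (1, 2, 3) by (X-1, Y-1, Z-1) = (9, 19, 29)
gx = 2 * 1 / 9 - 1
gy = 2 * 2 / 19 - 1
gz = 2 * 3 / 29 - 1

def unnormalize(g, size):
    # align_corners=True: -1 maps to index 0, +1 maps to index size - 1
    return (g + 1) / 2 * (size - 1)

# x indexes W, y indexes H, z indexes D
w = unnormalize(gx, W)  # 29/9, about 3.222
h = unnormalize(gy, H)  # 2.0
d = unnormalize(gz, D)  # 27/29, about 0.931
# img[d, h, w] = d + h + w is linear, so trilinear interpolation reproduces it exactly
value = d + h + w
print(round(value, 2))  # 6.15
```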

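This ordering mismatch is very annoying indeed. In your own solution, you happened to make two mistakes that cancel each other out :) The array is laid out as [X, Y, Z] instead of [Z, Y, X], and the grid coordinates are then reversed to (z, y, x).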
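One way to check that the two mistakes cancel: reversing the grid's coordinate order on an [X, Y, Z] volume samples the same values as permuting the volume to [Z, Y, X] and keeping the coordinates in (x, y, z) order. A sketch with random data; the shapes and the seed are arbitrary assumptions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Made-up volume laid out as [X, Y, Z] like the question's img
inp = torch.rand(1, 1, 10, 20, 30)
grid = torch.rand(1, 4, 5, 6, 3) * 2 - 1  # random sampling points in [-1, 1)

# Both "mistakes": [X, Y, Z] volume plus coordinates reversed to (z, y, x)
a = F.grid_sample(inp, grid.flip(-1), mode='bilinear', align_corners=True)
# Documented convention: volume permuted to [Z, Y, X], coordinates kept in (x, y, z)
b = F.grid_sample(inp.permute(0, 1, 4, 3, 2).contiguous(), grid,
                  mode='bilinear', align_corners=True)
print(torch.allclose(a, b, atol=1e-5))  # True
```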

https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

Upvotes: 2

Qiang Zhang

Reputation: 952

Finally, I found the solution. The code above returns an incorrect value because grid_sample accepts points in (z, y, x) order.

Thus, the correct code should be:

import torch
from torch.nn import functional as F
import numpy as np
X, Y, Z = 10, 20, 30
img = np.zeros(shape=[X, Y, Z], dtype=np.float32)
for i in range(X):
    for j in range(Y):
        for k in range(Z):
            img[i,j,k] = i+j+k
inp = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)
grid = torch.from_numpy(np.array([[1, 2, 3]], dtype=np.float32)).unsqueeze(1).unsqueeze(1).unsqueeze(1)
grid[..., 0] /= (X-1)
grid[..., 1] /= (Y-1)
grid[..., 2] /= (Z-1)

grid = 2*grid - 1

# Swap the first and last coordinates, i.e. reverse the order of the last axis
newgrid = grid.clone()
newgrid[..., 0] = grid[..., 2]
newgrid[..., 1] = grid[..., 1]
newgrid[..., 2] = grid[..., 0]

outp = F.grid_sample(inp, grid=newgrid, mode='bilinear', align_corners=True)
print(outp)
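For reference, the element-wise swap in the code above is just a reversal of the last axis, so `grid.flip(-1)` can replace the three assignments. A minimal sketch, assuming the same sizes and test point, with the nested loops replaced by np.meshgrid:

```python
import numpy as np
import torch
import torch.nn.functional as F

X, Y, Z = 10, 20, 30
# Same volume as above (img[i, j, k] = i + j + k), built without Python loops
i, j, k = np.meshgrid(np.arange(X), np.arange(Y), np.arange(Z), indexing='ij')
inp = torch.from_numpy((i + j + k).astype(np.float32))[None, None]  # (1, 1, X, Y, Z)

point = torch.tensor([1.0, 2.0, 3.0])  # (i, j, k)
sizes = torch.tensor([X, Y, Z], dtype=torch.float32)
grid = (2 * point / (sizes - 1) - 1).view(1, 1, 1, 1, 3)

# The element-wise swap from the answer equals reversing the last axis
newgrid = grid.clone()
newgrid[..., 0], newgrid[..., 2] = grid[..., 2], grid[..., 0]
print(torch.equal(newgrid, grid.flip(-1)))  # True

outp = F.grid_sample(inp, grid.flip(-1), mode='bilinear', align_corners=True)
print(outp.item())  # approximately 6.0
```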

Upvotes: 2
