Sergiy Snake

Reputation: 21

How to optimize the custom bilinear sampling alternative to grid_sample for TensorRT inference?

I am trying to convert a model that uses torch.nn.functional.grid_sample from PyTorch (1.6) to TensorRT (7) through ONNX (opset 11). Opset 11 does not support grid_sample conversion. The custom alternative I found (https://github.com/pytorch/pytorch/issues/27212) is extremely slow when running in PyTorch and has a problem with converting its main loop to TRT.
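
A minimal repro of the failing export (the module and tensor shapes here are illustrative, not taken from the actual model):

import torch
import torch.nn.functional as F

class GridSampleModule(torch.nn.Module):
    def forward(self, image, grid):
        return F.grid_sample(image, grid, align_corners=False)

image = torch.randn(1, 1, 32, 32)
grid = torch.rand(1, 32, 32, 2) * 2 - 1  # normalized coordinates in [-1, 1]

# Raises an export error: grid_sample has no ONNX mapping at opset 11.
torch.onnx.export(GridSampleModule(), (image, grid), "grid_sample.onnx",
                  opset_version=11)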

My own implementation of bilinear sampling (not just grid_sample, but the whole original sampling based on grid_sample) runs much faster in PyTorch and converts to TRT successfully. But my custom bilinear sampling in TRT is slower than the PyTorch one (5.6 ms vs 2.0 ms). It turns out that the PyTorch indexing image[:, ind, y0, x0] produces a Gather layer with a running time of about 0.97 ms, and there are four such layers in the TRT version of this bilinear sampling.

So the questions are: why does a Gather layer take about 1 ms in TensorRT, and how can this bilinear sampling be reworked so that the Gather operations (or their alternatives) run faster?

Here is the code of the bilinear sampling function:

import torch

def bilinear_sample_noloop(image, grid):
    """
    :param image: sampling source of shape [N, C, H, W]
    :param grid: pixel coordinates (float, unnormalized, last dim = (x, y))
                 of shape [N, grid_H, grid_W, 2]
    :return: sampling result of shape [N, C, grid_H, grid_W]
    """
    Nt, C, H, W = image.shape
    grid_H = grid.shape[1]
    grid_W = grid.shape[2]
    xgrid, ygrid = grid.split([1, 1], dim=-1)
    # Valid-sample mask; out-of-bounds coordinates are zeroed out below.
    mask = ((xgrid >= 0) & (ygrid >= 0) & (xgrid < W - 1) & (ygrid < H - 1)).float()
    x0 = torch.floor(xgrid)
    x1 = x0 + 1
    y0 = torch.floor(ygrid)
    y1 = y0 + 1
    # Bilinear weights of the four neighboring pixels, moved to dim 0.
    wa = ((x1 - xgrid) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wb = ((x1 - xgrid) * (ygrid - y0)).permute(3, 0, 1, 2)
    wc = ((xgrid - x0) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wd = ((xgrid - x0) * (ygrid - y0)).permute(3, 0, 1, 2)
    # Zero the out-of-bounds indices so the gathers below stay in range.
    x0 = (x0 * mask).view(Nt, grid_H, grid_W).long()
    y0 = (y0 * mask).view(Nt, grid_H, grid_W).long()
    x1 = (x1 * mask).view(Nt, grid_H, grid_W).long()
    y1 = (y1 * mask).view(Nt, grid_H, grid_W).long()
    # Batch index broadcast to the grid shape.
    ind = torch.arange(Nt, device=image.device)  # torch.linspace(0, Nt - 1, Nt, device=image.device)
    ind = ind.view(Nt, 1).expand(-1, grid_H).view(Nt, grid_H, 1).expand(-1, -1, grid_W).long()
    image = image.permute(1, 0, 2, 3)
    # Each image[...] indexing becomes a Gather layer in TRT (four in total).
    output_tensor = (image[:, ind, y0, x0] * wa + image[:, ind, y1, x0] * wb +
                     image[:, ind, y0, x1] * wc + image[:, ind, y1, x1] * wd).permute(1, 0, 2, 3)
    output_tensor *= mask.permute(0, 3, 1, 2).expand(-1, C, -1, -1)
    return output_tensor, mask
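
For context, a quick way to sanity-check this against grid_sample (a sketch; the sizes are illustrative, and note that the grid here is in pixel units while grid_sample expects normalized coordinates):

import torch
import torch.nn.functional as F

N, C, H, W = 1, 1, 64, 64
image = torch.randn(N, C, H, W, device='cuda')

# Pixel-space grid with x in [0, W - 1) and y in [0, H - 1).
ys, xs = torch.meshgrid(torch.linspace(0, H - 2, 32, device='cuda'),
                        torch.linspace(0, W - 2, 32, device='cuda'))
grid = torch.stack((xs, ys), dim=-1).unsqueeze(0)  # [1, 32, 32, 2], (x, y)

out, mask = bilinear_sample_noloop(image, grid)

# grid_sample with align_corners=True maps [-1, 1] to [0, W - 1] in pixels.
norm = torch.empty_like(grid)
norm[..., 0] = grid[..., 0] / (W - 1) * 2 - 1
norm[..., 1] = grid[..., 1] / (H - 1) * 2 - 1
ref = F.grid_sample(image, norm, mode='bilinear', align_corners=True)
print((out - ref).abs().max())  # should be ~0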

Time profiling parameters:

A part of the TRT model profiling with trtexec:

      Layer   Time (ms)   Avg. Time (ms)   Time %
...
    Mul_146        5.82             0.03      0.5
    Add_147        8.50             0.04      0.7
 Gather_148      214.39             0.97     17.3
 Gather_174      214.25             0.97     17.3
 Gather_201      213.88             0.97     17.3
 Gather_228      214.48             0.97     17.3
    Add_237       25.01             0.11      2.0
    Mul_251        7.84             0.04      0.6
      Total     1238.40             5.60    100.0

Additionally, I tried viewing the image as a linear array over all dimensions except C and creating linear indexes to address elements in the form image[:, p0]. In this case Gather becomes even slower (about 1.07 ms). Then I considered C = 1 (as it always is in the original model) and addressed the tensor elements as image[p0]. This time Gather takes about 0.92 ms (still too slow).
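
For reference, a self-contained sketch of that flattened-index variant (same contract as bilinear_sample_noloop above; hypothetical name, and as noted its Gather was measured even slower, about 1.07 ms):

import torch

def bilinear_sample_flat(image, grid):
    # Sketch of the flattened addressing described above: each corner is
    # fetched with one linear index into the image viewed as [C, N*H*W].
    Nt, C, H, W = image.shape
    grid_H, grid_W = grid.shape[1], grid.shape[2]
    xgrid, ygrid = grid.split([1, 1], dim=-1)
    mask = ((xgrid >= 0) & (ygrid >= 0) &
            (xgrid < W - 1) & (ygrid < H - 1)).float()
    x0 = torch.floor(xgrid)
    x1 = x0 + 1
    y0 = torch.floor(ygrid)
    y1 = y0 + 1
    wa = ((x1 - xgrid) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wb = ((x1 - xgrid) * (ygrid - y0)).permute(3, 0, 1, 2)
    wc = ((xgrid - x0) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wd = ((xgrid - x0) * (ygrid - y0)).permute(3, 0, 1, 2)
    x0 = (x0 * mask).view(Nt, grid_H, grid_W).long()
    y0 = (y0 * mask).view(Nt, grid_H, grid_W).long()
    x1 = (x1 * mask).view(Nt, grid_H, grid_W).long()
    y1 = (y1 * mask).view(Nt, grid_H, grid_W).long()
    ind = torch.arange(Nt, device=image.device).view(Nt, 1, 1)
    im = image.permute(1, 0, 2, 3).reshape(C, -1)  # [C, N*H*W]
    # Linear index: n*H*W + y*W + x for each of the four corners.
    p00 = (ind * H + y0) * W + x0
    p10 = (ind * H + y1) * W + x0
    p01 = (ind * H + y0) * W + x1
    p11 = (ind * H + y1) * W + x1
    out = (im[:, p00] * wa + im[:, p10] * wb +
           im[:, p01] * wc + im[:, p11] * wd).permute(1, 0, 2, 3)
    out *= mask.permute(0, 3, 1, 2).expand(-1, C, -1, -1)
    return out, mask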

Upvotes: 2

Views: 2432

Answers (1)

loran d

Reputation: 98

The following code performs bilinear interpolation of an image and can be converted from PyTorch to TensorRT:

import torch

def bilinear_interpolate_torch(im, y, x):
    '''
    im : B,C,H,W
    y : 1,numPoints -- pixel location y float
    x : 1,numPoints -- pixel location x float
    '''
    x0 = torch.floor(x).type(torch.cuda.LongTensor)
    x1 = x0 + 1

    y0 = torch.floor(y).type(torch.cuda.LongTensor)
    y1 = y0 + 1

    # Bilinear weights of the four neighboring pixels.
    wa = (x1.type(torch.cuda.FloatTensor) - x) * (y1.type(torch.cuda.FloatTensor) - y)
    wb = (x1.type(torch.cuda.FloatTensor) - x) * (y - y0.type(torch.cuda.FloatTensor))
    wc = (x - x0.type(torch.cuda.FloatTensor)) * (y1.type(torch.cuda.FloatTensor) - y)
    wd = (x - x0.type(torch.cuda.FloatTensor)) * (y - y0.type(torch.cuda.FloatTensor))
    # Instead of clamp: pulls x1 == W back to W - 1 (and y1 == H to H - 1).
    x1 = x1 - torch.floor(x1 / im.shape[3]).int()
    y1 = y1 - torch.floor(y1 / im.shape[2]).int()
    Ia = im[:, :, y0, x0]
    Ib = im[:, :, y1, x0]
    Ic = im[:, :, y0, x1]
    Id = im[:, :, y1, x1]

    return Ia * wa + Ib * wb + Ic * wc + Id * wd
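
A quick usage sketch (values are illustrative; a CUDA device is required because the function casts to torch.cuda tensor types):

im = torch.arange(16., device='cuda').view(1, 1, 4, 4)
y = torch.tensor([[0.5, 1.5]], device='cuda')
x = torch.tensor([[0.5, 2.5]], device='cuda')
print(bilinear_interpolate_torch(im, y, x))  # shape [1, 1, 1, 2]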

Upvotes: 0
