Azade Farshad

Differentiable affine transformation on patches of images in pytorch

I have a tensor of object bounding boxes, e.g. with shape [10,4], which correspond to a batch of images, e.g. with shape [2,3,64,64], along with transformation matrices for each object with shape [10,6] and a vector that defines which object index belongs to which image. I would like to apply the affine transformations to patches of the images and replace those patches with their transformed versions. I am currently doing this with a for loop, but the way I am doing it is not differentiable (I get the in-place operation error from PyTorch). Is there a differentiable way to do this, e.g. via grid_sample?

Here is my current code:

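for obj_num in range(obj_vecs.shape[0]):  # number of objects across the whole batch
    im_id = obj_to_img[obj_num]           # index of the image this object belongs to
    x1, y1, x2, y2 = boxes_pred[obj_num]  # x slices the height axis, y the width axis
    im_patch = img[im_id, :, x1:x2, y1:y2]
    im_patch = im_patch[None, :, :, :]    # add a batch dimension: [1, C, h, w]
    # Writing the result back into img in place is what autograd complains about
    img[im_id, :, x1:x2, y1:y2] = self.VITAE.stn(im_patch, theta_mean[obj_num], inverse=False)[0]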
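A minimal sketch that reproduces the error, assuming a hypothetical `toy_stn` built from `F.affine_grid` and `F.grid_sample` as a stand-in for `self.VITAE.stn` (whose code isn't shown). `grid_sample` saves its input patch for the backward pass; that patch is a view of `img`, so the in-place write bumps the version counter it shares with `img` and `backward()` refuses to run.

```python
import torch
import torch.nn.functional as F


def toy_stn(patch, theta):
    # Hypothetical stand-in for self.VITAE.stn: warps a [1, C, h, w] patch
    # with a 2x3 affine matrix (given as 6 numbers) via affine_grid + grid_sample.
    grid = F.affine_grid(theta.view(1, 2, 3), list(patch.shape), align_corners=False)
    return F.grid_sample(patch, grid, align_corners=False)


feat = torch.rand(1, 3, 8, 8, requires_grad=True)
img = feat * 2.0  # non-leaf image, e.g. the output of an earlier layer
theta = torch.tensor([1.0, 0.0, 0.1, 0.0, 1.0, 0.0], requires_grad=True)

patch = img[:, :, 2:6, 2:6]                  # a view into img, saved by grid_sample
img[:, :, 2:6, 2:6] = toy_stn(patch, theta)  # in-place write bumps img's version

error = None
try:
    img.sum().backward()
except RuntimeError as e:
    error = str(e)
print(error)  # a RuntimeError message about an inplace modification
```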


Answers (1)

Ivan

There are a few ways to perform differentiable crops in PyTorch.

Let's take a minimal example in 2D:

>>> x1, y1, x2, y2 = torch.randint(0, 9, (4,))
(tensor(7), tensor(5), tensor(3), tensor(6))

>>> x = torch.randint(0, 100, (9,9), dtype=float, requires_grad=True)
tensor([[18., 34., 28., 41.,  1., 14., 77., 75., 23.],
        [62., 33., 64., 41., 16., 70., 47., 45., 19.],
        [20., 69.,  5., 51.,  1., 16., 20., 63., 52.],
        [51., 25.,  8., 30., 40., 67., 41., 27., 33.],
        [36.,  6., 95., 53., 69., 84., 51., 42., 71.],
        [46., 72., 88., 82., 71., 75., 86., 36., 15.],
        [66., 19., 58., 50., 91., 28.,  7., 83.,  4.],
        [94., 50., 34., 34., 92., 45., 48., 97., 76.],
        [80., 34., 19., 13., 77., 77., 51., 15., 13.]], dtype=torch.float64,
       requires_grad=True)

Given x1, x2 (resp. y1, y2), the patch index boundaries on the height dimension (resp. width dimension), you can get the grid of coordinates corresponding to your patch using a combination of torch.arange and torch.meshgrid (sorted_range swaps the bounds when the second one is smaller):

>>> sorted_range = lambda a, b: torch.arange(a, b) if b >= a else torch.arange(b, a)
>>> xi, yi = sorted_range(x1, x2), sorted_range(y1, y2)
(tensor([3, 4, 5, 6]), tensor([5]))

>>> i, j = torch.meshgrid(xi, yi, indexing='ij')
(tensor([[3],
         [4],
         [5],
         [6]]), 
 tensor([[5],
         [5],
         [5],
         [5]]))
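These grids index a single 2D map. For the [N, C, H, W] batch in the question, one way to extend them (a sketch; `patch_index` is a hypothetical helper, and integer box coordinates with x on the height axis are assumed, as in the question's slicing) is to broadcast the row/column grids against a channel index and a constant image index. The `indexing="ij"` keyword of `torch.meshgrid` needs PyTorch 1.10 or later.

```python
import torch


def patch_index(im_id, x1, y1, x2, y2, channels):
    """Hypothetical helper: tensor indices selecting one box of one image.

    For x1 <= x2 and y1 <= y2 this picks the same elements as
    img[im_id, :, x1:x2, y1:y2]; the four returned tensors broadcast
    to shape [channels, h, w].
    """
    rows = torch.arange(min(x1, x2), max(x1, x2))
    cols = torch.arange(min(y1, y2), max(y1, y2))
    i, j = torch.meshgrid(rows, cols, indexing="ij")   # each [h, w]
    c = torch.arange(channels).view(-1, 1, 1)           # [C, 1, 1]
    b = torch.full((1, 1, 1), im_id, dtype=torch.long)  # [1, 1, 1]
    return b, c, i.unsqueeze(0), j.unsqueeze(0)


img = torch.rand(2, 3, 64, 64, requires_grad=True)
idx = patch_index(1, 10, 20, 30, 52, channels=3)
patch = img[idx]  # shape [3, 20, 32]; tensor indexing keeps it in the autograd graph
```

The same tuple `idx` can be passed as the `indices` argument of `index_put` to write a [3, 20, 32] result back into the batch.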

With that setup, you can extract and replace patches of x.

  1. You can extract the patch by indexing x directly:

    >>> patch = x[i, j].reshape(len(xi), len(yi))
    tensor([[67.],
            [84.],
            [75.],
            [28.]], dtype=torch.float64, grad_fn=<ViewBackward>)
    

    Here is the mask for illustration purposes:

    tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64,
    grad_fn=<IndexPutBackward>)
    
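  2. You can replace the values in x with the result of some transformation on the patch using torch.Tensor.index_put. Unlike its in-place counterpart index_put_, it returns a new tensor, so autograd can track the operation: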

    >>> values = 2*patch
    tensor([[134.],
            [168.],
            [150.],
            [ 56.]], dtype=torch.float64, grad_fn=<MulBackward0>)
    
    >>> x.index_put(indices=(i, j), values=values)
    tensor([[ 18.,  34.,  28.,  41.,   1.,  14.,  77.,  75.,  23.],
            [ 62.,  33.,  64.,  41.,  16.,  70.,  47.,  45.,  19.],
            [ 20.,  69.,   5.,  51.,   1.,  16.,  20.,  63.,  52.],
            [ 51.,  25.,   8.,  30.,  40., 134.,  41.,  27.,  33.],
            [ 36.,   6.,  95.,  53.,  69., 168.,  51.,  42.,  71.],
            [ 46.,  72.,  88.,  82.,  71., 150.,  86.,  36.,  15.],
            [ 66.,  19.,  58.,  50.,  91.,  56.,   7.,  83.,   4.],
            [ 94.,  50.,  34.,  34.,  92.,  45.,  48.,  97.,  76.],
            [ 80.,  34.,  19.,  13.,  77.,  77.,  51.,  15.,  13.]],
        dtype=torch.float64, grad_fn=<IndexPutBackward>)
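Putting both steps together for the setup in the question gives roughly the following sketch. It assumes each flattened 2x3 matrix is applied with `F.affine_grid`/`F.grid_sample` (a stand-in for `self.VITAE.stn`, which isn't shown), and `paste_transformed` is a hypothetical name. Every write goes through the out-of-place `index_put`, so nothing autograd has saved is modified, and gradients reach both the image and the matrices.

```python
import torch
import torch.nn.functional as F


def paste_transformed(img, boxes, thetas, obj_to_img):
    """Warp each box of `img` with its own affine matrix and paste it back.

    Assumed shapes (from the question): img [N, C, H, W]; boxes [O, 4] holding
    integer (x1, y1, x2, y2) with x on the height axis, x1 < x2 and y1 < y2;
    thetas [O, 6], one flattened 2x3 affine matrix per box; obj_to_img [O].
    Overlapping boxes: later ones read and overwrite earlier results.
    """
    out = img
    C = img.shape[1]
    for o in range(boxes.shape[0]):
        x1, y1, x2, y2 = (int(v) for v in boxes[o])
        b = int(obj_to_img[o])
        i, j = torch.meshgrid(torch.arange(x1, x2, device=img.device),
                              torch.arange(y1, y2, device=img.device),
                              indexing="ij")                   # each [h, w]
        c = torch.arange(C, device=img.device).view(-1, 1, 1)  # [C, 1, 1]
        idx = (torch.full_like(c, b), c, i.unsqueeze(0), j.unsqueeze(0))

        patch = out[idx].unsqueeze(0)                          # [1, C, h, w]
        grid = F.affine_grid(thetas[o].view(1, 2, 3), list(patch.shape),
                             align_corners=False)
        warped = F.grid_sample(patch, grid, align_corners=False)[0]  # [C, h, w]
        out = out.index_put(idx, warped)                       # new tensor, no in-place write
    return out


img = torch.rand(2, 3, 64, 64, requires_grad=True)
boxes = torch.tensor([[4, 4, 20, 36], [30, 10, 60, 40], [0, 0, 16, 16]])
obj_to_img = torch.tensor([0, 0, 1])
thetas = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]] * 3, requires_grad=True)

out = paste_transformed(img, boxes, thetas, obj_to_img)
out.sum().backward()  # gradients reach both img and thetas
```

With identity matrices, as here, out matches img up to floating-point error; in the question's code theta_mean would take their place. The trade-off of staying out of place is that each index_put allocates a fresh copy of the whole batch.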
    
