Reputation: 1062
I have a tensor of object bounding boxes, e.g. with shape [10,4], that correspond to a batch of images, e.g. with shape [2,3,64,64]; a transformation matrix for each object, with shape [10,6]; and a vector that defines which object index belongs to which image. I would like to apply the affine transformations to patches of the images and put the transformed patches back in place. I am currently doing this with a for loop, but the way I am doing it is not differentiable (I get the in-place operation error from PyTorch). Is there a differentiable way to do this, e.g. via grid_sample?
Here is my current code:
for obj_num in range(obj_vecs.shape[0]):  # batch_size
    im_id = obj_to_img[obj_num]
    x1, y1, x2, y2 = boxes_pred[obj_num]
    im_patch = img[im_id, :, x1:x2, y1:y2]
    im_patch = im_patch[None, :, :, :]
    img[im_id, :, x1:x2, y1:y2] = self.VITAE.stn(im_patch, theta_mean[obj_num], inverse=False)[0]
Upvotes: 1
Views: 985
Reputation: 40658
There are a few ways to perform differentiable crops in PyTorch.
Let's take a minimal example in 2D:
>>> x1, x2, y1, y2 = torch.randint(0, 9, (4,))
(tensor(7), tensor(3), tensor(5), tensor(6))
>>> x = torch.randint(0, 100, (9,9), dtype=float, requires_grad=True)
tensor([[18., 34., 28., 41., 1., 14., 77., 75., 23.],
[62., 33., 64., 41., 16., 70., 47., 45., 19.],
[20., 69., 5., 51., 1., 16., 20., 63., 52.],
[51., 25., 8., 30., 40., 67., 41., 27., 33.],
[36., 6., 95., 53., 69., 84., 51., 42., 71.],
[46., 72., 88., 82., 71., 75., 86., 36., 15.],
[66., 19., 58., 50., 91., 28., 7., 83., 4.],
[94., 50., 34., 34., 92., 45., 48., 97., 76.],
[80., 34., 19., 13., 77., 77., 51., 15., 13.]], dtype=torch.float64,
requires_grad=True)
Given x1, x2 (resp. y1, y2), the patch index boundaries on the height dimension (resp. width dimension), you can get the grid of coordinates corresponding to your patch using a combination of torch.arange and torch.meshgrid:
>>> sorted_range = lambda a, b: torch.arange(a, b) if b >= a else torch.arange(b, a)
>>> xi, yi = sorted_range(x1, x2), sorted_range(y1, y2)
(tensor([3, 4, 5, 6]), tensor([5]))
>>> i, j = torch.meshgrid(xi, yi)
(tensor([[3],
[4],
[5],
[6]]),
tensor([[5],
[5],
[5],
[5]]))
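On recent PyTorch releases, torch.meshgrid warns when it is called without an explicit indexing argument. The sketch below repeats the step above with indexing="ij", assuming a version that accepts that keyword (1.10 or later, as far as I know). The bounds are hard-coded to reproduce the index ranges printed above, and sorted_range is written as a plain function rather than a lambda; neither choice comes from the original answer.

```python
import torch

# Hard-coded bounds (an assumption) that give rows 3..6 and column 5.
x1, x2, y1, y2 = 7, 3, 5, 6

def sorted_range(a, b):
    # Ascending half-open range between a and b, whatever their order.
    lo, hi = (a, b) if a <= b else (b, a)
    return torch.arange(lo, hi)

xi, yi = sorted_range(x1, x2), sorted_range(y1, y2)

# indexing="ij" keeps matrix-style (row, column) ordering, which is what
# the indexing x[i, j] used later expects.
i, j = torch.meshgrid(xi, yi, indexing="ij")
print(i.shape, j.shape)
```

Both i and j have shape (len(xi), len(yi)), so they can be used together as advanced indices into x.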
With that setup, you can extract and replace patches of x.
You can extract the patch by indexing x directly:
>>> patch = x[i, j].reshape(len(xi), len(yi))
tensor([[67.],
[84.],
[75.],
[28.]], dtype=torch.float64, grad_fn=<ViewBackward>)
Here is the mask for illustration purposes:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64,
grad_fn=<IndexPutBackward>)
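The answer does not show how the mask was built. As a minimal sketch of one possible way (not necessarily the author's), the mask can be written with an out-of-place index_put on a zero tensor, and the same pattern shows up as the gradient of patch.sum() with respect to x, which confirms that the indexed extraction is differentiable. The tensor x here is a deterministic stand-in for the random one above.

```python
import torch

# Deterministic stand-in for the random 9x9 tensor used above.
x = torch.arange(81, dtype=torch.float64).reshape(9, 9).requires_grad_()

# Same index grids as in the example: rows 3..6, column 5.
i, j = torch.meshgrid(torch.arange(3, 7), torch.arange(5, 6), indexing="ij")

patch = x[i, j]  # shape (4, 1), still attached to the autograd graph

# Write ones at the patch positions of a zero tensor (out-of-place).
mask = torch.zeros_like(x).index_put(indices=(i, j), values=torch.ones_like(patch))

# Each element of x contributes to patch.sum() once if it lies in the
# patch and not at all otherwise, so the gradient equals the mask.
patch.sum().backward()
print(torch.equal(x.grad, mask))
```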
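You can replace the values in x with the result from some transformation on the patch using torch.Tensor.index_put, the out-of-place counterpart of index_put_, which returns a new tensor and leaves x itself unchanged: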
>>> values = 2*patch
tensor([[134.],
[168.],
[150.],
[ 56.]], dtype=torch.float64, grad_fn=<MulBackward0>)
>>> x.index_put(indices=(i, j), values=values)
tensor([[ 18., 34., 28., 41., 1., 14., 77., 75., 23.],
[ 62., 33., 64., 41., 16., 70., 47., 45., 19.],
[ 20., 69., 5., 51., 1., 16., 20., 63., 52.],
[ 51., 25., 8., 30., 40., 134., 41., 27., 33.],
[ 36., 6., 95., 53., 69., 168., 51., 42., 71.],
[ 46., 72., 88., 82., 71., 150., 86., 36., 15.],
[ 66., 19., 58., 50., 91., 56., 7., 83., 4.],
[ 94., 50., 34., 34., 92., 45., 48., 97., 76.],
[ 80., 34., 19., 13., 77., 77., 51., 15., 13.]],
dtype=torch.float64, grad_fn=<IndexPutBackward>)
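Carried over to the batched setting of the question, the same idea might look like the sketch below. It is only an illustration under stated assumptions: replace_patches is a hypothetical helper, boxes hold integer (x1, y1, x2, y2) coordinates with x on the height axis, and transform is a shape-preserving stand-in (here simply 2 * patch, as in the example above) for the STN call, which is not reproduced here.

```python
import torch

def replace_patches(img, boxes, obj_to_img, transform):
    """Return a new tensor where each box region is replaced by transform(patch).

    img: (N, C, H, W); boxes: (K, 4) integer (x1, y1, x2, y2), x on the height axis;
    obj_to_img: (K,) image index per box; transform: shape-preserving callable
    (hypothetical stand-in for the STN call in the question).
    """
    out = img
    channels = torch.arange(img.shape[1], device=img.device)
    for k in range(boxes.shape[0]):
        x1, y1, x2, y2 = (int(v) for v in boxes[k])
        rows = torch.arange(x1, x2, device=img.device)
        cols = torch.arange(y1, y2, device=img.device)
        c, r, w = torch.meshgrid(channels, rows, cols, indexing="ij")
        b = torch.full_like(c, int(obj_to_img[k]))  # image index, same shape as c
        patch = out[b, c, r, w]                     # shape (C, x2 - x1, y2 - y1)
        # index_put (no trailing underscore) is out-of-place, so no tensor
        # saved by autograd is overwritten.
        out = out.index_put((b, c, r, w), transform(patch))
    return out

img = torch.rand(2, 3, 8, 8, dtype=torch.float64, requires_grad=True)
boxes = torch.tensor([[1, 1, 4, 5], [2, 0, 6, 3]])
obj_to_img = torch.tensor([0, 1])

result = replace_patches(img, boxes, obj_to_img, lambda p: 2 * p)
result.sum().backward()  # gradients reach img without an in-place error
```

If the real transformation expects a batch dimension, as the STN call in the question appears to, the callable can wrap it, e.g. by passing patch[None] and taking element [0] of the output. Each iteration allocates a new full-size tensor, so this trades memory for a clean autograd graph.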
Upvotes: 1