N.Z

PyTorch (NumPy) computation of the closest pixels to points along a line

I am trying to solve a complicated problem.

For example, I have a batch of 2D predicted images (softmax output, values between 0 and 1) of size Batch x H x W, and ground-truth images of size Batch x H x W.

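[Image: example ground-truth map; background in light gray, foreground in dark gray, mass center C in red, pixel A in yellow, points B1, B2, B3 in blue]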

The light gray pixels are the background (value 0) and the dark gray pixels are the foreground (value 1). I compute the mass center coordinates of each ground-truth image with scipy.ndimage.center_of_mass, which gives a center point C (red) for each ground truth. The set of C points therefore has one point per image (Batch x 1 points, i.e. Batch x 2 coordinates).
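If the per-image calls to scipy.ndimage.center_of_mass are a bottleneck, the same quantity can be computed for the whole batch at once with NumPy broadcasting. This is a minimal sketch (the name batch_center_of_mass is hypothetical), assuming the ground truth is a (Batch, H, W) array of non-negative weights:

```python
import numpy as np

def batch_center_of_mass(masks):
    """Center of mass of every image in a (Batch, H, W) array of non-negative weights.

    Per image this is sum(w * coordinate) / sum(w), the same quantity that
    scipy.ndimage.center_of_mass computes for a single image without labels.
    Returns a (Batch, 2) float array of (row, col) centers.
    """
    _, H, W = masks.shape
    rows = np.arange(H)[None, :, None]          # (1, H, 1), broadcast over columns
    cols = np.arange(W)[None, None, :]          # (1, 1, W), broadcast over rows
    total = masks.sum(axis=(1, 2))              # (Batch,); an all-zero mask yields NaN
    r = (masks * rows).sum(axis=(1, 2)) / total
    c = (masks * cols).sum(axis=(1, 2)) / total
    return np.stack([r, c], axis=1)

# Toy ground truth: two 9 x 9 binary masks with square foreground regions.
gt = np.zeros((2, 9, 9))
gt[0, 2:5, 3:6] = 1
gt[1, 4:7, 1:4] = 1
C = batch_center_of_mass(gt)                    # [[3., 4.], [5., 2.]]
C_int = np.rint(C).astype(np.int64)             # nearest pixel, if integer coordinates are needed
```

Rounding to the nearest pixel is an assumption here; it is only needed if C is later used as an integer pixel index.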

Now, for each pixel A (yellow) in the predicted images, I want to get the three pixels B1, B2, B3 (blue) on the line AC that are closest to A (here C is the corresponding mass center location in the ground truth).

I used the following code to get the three closest points B1, B2, B3:

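import numpy as np

def connect(ends, m=3):
    # ends: 2 x 2 integer array [[row_A, col_A], [row_C, col_C]].
    # Returns m + 1 integer pixel coordinates, roughly evenly spaced from A to C (both ends included).
    d0, d1 = np.abs(np.diff(ends, axis=0))[0]
    if d0 > d1:
        # The segment spans more pixels along axis 0: step along axis 0, round along axis 1.
        return np.c_[np.linspace(ends[0, 0], ends[1, 0], m + 1, dtype=np.int32),
                     np.round(np.linspace(ends[0, 1], ends[1, 1], m + 1)).astype(np.int32)]
    else:
        # Otherwise step along axis 1 and round along axis 0.
        return np.c_[np.round(np.linspace(ends[0, 0], ends[1, 0], m + 1)).astype(np.int32),
                     np.linspace(ends[0, 1], ends[1, 1], m + 1, dtype=np.int32)]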
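Running connect on a concrete pair shows what it returns. The snippet repeats the function so it runs on its own and uses A = (0, 1), C = (6, 4), the first pair from the answer below. With m=3 it returns four points: A itself, two intermediate points and C, spread over the whole segment rather than the pixels directly adjacent to A.

```python
import numpy as np

def connect(ends, m=3):
    # Same logic as the function above, repeated so this snippet is self-contained.
    d0, d1 = np.abs(np.diff(ends, axis=0))[0]
    if d0 > d1:
        return np.c_[np.linspace(ends[0, 0], ends[1, 0], m + 1, dtype=np.int32),
                     np.round(np.linspace(ends[0, 1], ends[1, 1], m + 1)).astype(np.int32)]
    return np.c_[np.round(np.linspace(ends[0, 0], ends[1, 0], m + 1)).astype(np.int32),
                 np.linspace(ends[0, 1], ends[1, 1], m + 1, dtype=np.int32)]

ends = np.array([[0, 1], [6, 4]])  # rows: A = (0, 1), C = (6, 4)
pts = connect(ends)
print(pts.tolist())  # [[0, 1], [2, 2], [4, 3], [6, 4]]
```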

So the set of B points has size Batch x 3 x H x W (each B being a row/column coordinate pair).
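One way to build that Batch x 3 x H x W set without a Python loop over pixels is to step from every pixel A towards C with a DDA (digital differential analyzer) line and round to the nearest pixel. This is a swapped-in technique, not Bresenham: because np.rint rounds halves to even, it can pick a different pixel than Bresenham on ties. The sketch assumes integer centers C and uses a hypothetical function name:

```python
import numpy as np

def line_points_dda(centers, H, W, num_points=3):
    """Pixels following every A = (i, j) of an H x W grid on the segment A -> C.

    centers: (Batch, 2) integer (row, col) center C of each image.
    Returns rows, cols: integer arrays of shape (Batch, num_points, H, W).
    """
    ii, jj = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")  # A coordinates, (H, W)
    d_r = centers[:, 0, None, None] - ii                              # (Batch, H, W)
    d_c = centers[:, 1, None, None] - jj
    steps = np.maximum(np.abs(d_r), np.abs(d_c))                      # DDA steps from A to C
    k = np.arange(1, num_points + 1)[None, :, None, None]             # (1, num_points, 1, 1)
    k = np.minimum(k, steps[:, None])                                 # never step past C
    t = k / np.maximum(steps, 1)[:, None]                             # A == C gives t = 0
    rows = np.rint(ii + t * d_r[:, None]).astype(np.int64)
    cols = np.rint(jj + t * d_c[:, None]).astype(np.int64)
    return rows, cols

centers = np.array([[6, 4], [2, 3]])
rows, cols = line_points_dda(centers, 9, 9)
print(rows.shape)  # (2, 3, 9, 9)
```

Clamping the step count to the distance from A to C means pixels close to C simply repeat C, and A == C yields B = A, which contributes zero to the sum of absolute differences.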

Then I want to compute |Value(A) - Value(B1)| + |Value(A) - Value(B2)| + |Value(A) - Value(B3)| for every pixel A, so the result should have size Batch x H x W.
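Given integer coordinates for the B points of every pixel, the sum can be computed for the whole batch with advanced indexing. This is a NumPy sketch with a hypothetical function name. The same indexing expression (a batch index from torch.arange plus integer index tensors) should carry over to a PyTorch softmax output that requires grad. There, the gradient flows into the prediction through the gathered values, while the integer coordinates themselves carry no gradient.

```python
import numpy as np

def abs_diff_sum(pred, rows, cols):
    """sum_k |pred[b, i, j] - pred[b, rows[b, k, i, j], cols[b, k, i, j]]| for every pixel.

    pred: (Batch, H, W) float array of predictions.
    rows, cols: (Batch, K, H, W) integer coordinates of the K points B_k of each pixel.
    Returns an array of shape (Batch, H, W).
    """
    batch = np.arange(pred.shape[0])[:, None, None, None]   # (Batch, 1, 1, 1)
    values_B = pred[batch, rows, cols]                      # (Batch, K, H, W) via advanced indexing
    values_A = pred[:, None]                                # (Batch, 1, H, W), broadcast over K
    return np.abs(values_A - values_B).sum(axis=1)

# Toy check: one 2 x 2 image, K = 2 points fixed at (0, 0) and (1, 1) for every pixel.
pred = np.array([[[0.1, 0.4], [0.6, 0.9]]])
rows = np.zeros((1, 2, 2, 2), dtype=np.int64)
rows[:, 1] = 1
cols = rows.copy()
res = abs_diff_sum(pred, rows, cols)  # every pixel: |p - 0.1| + |p - 0.9| = 0.8
```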

Are there any NumPy vectorization tricks that can be used to update the value of each pixel in the predicted images? Or can this be solved with PyTorch functions? I need a method that updates the whole image at once. The predicted image is the softmax output, and I cannot use a for loop to compute each value individually, since it would become non-differentiable. Thanks a lot.


Answers (1)

benjaminplanche

As suggested by @Matin, you could consider Bresenham's algorithm to get your points on the AC line.

A simplistic PyTorch implementation could be as follows (directly adapted from the pseudo-code here; it could be optimized):

import torch

def get_points_from_low(x0, y0, x1, y1, num_points=3):
    dx = x1 - x0
    dy = y1 - y0
    xi = torch.sign(dx)
    yi = torch.sign(dy)
    dy = dy * yi
    D = 2 * dy - dx

    y = y0
    x = x0

    points = []
    for n in range(num_points):
        x = x + xi
        is_D_gt_0 = (D > 0).long()
        y = y + is_D_gt_0 * yi
        D = D + 2 * dy - is_D_gt_0 * 2 * dx

        points.append(torch.stack((x, y), dim=-1))

    return torch.stack(points, dim=len(x0.shape))

def get_points_from_high(x0, y0, x1, y1, num_points=3):
    dx = x1 - x0
    dy = y1 - y0
    xi = torch.sign(dx)
    yi = torch.sign(dy)
    dx = dx * xi
    D = 2 * dx - dy

    y = y0
    x = x0

    points = []
    for n in range(num_points):
        y = y + yi
        is_D_gt_0 = (D > 0).long()
        x = x + is_D_gt_0 * xi
        D = D + 2 * dx - is_D_gt_0 * 2 * dy

        points.append(torch.stack((x, y), dim=-1))

    return torch.stack(points, dim=len(x0.shape))

def get_points_from(x0, y0, x1, y1, num_points=3):
    is_dy_lt_dx = (torch.abs(y1 - y0) < torch.abs(x1 - x0)).long()
    is_x0_gt_x1 = (x0 > x1).long()
    is_y0_gt_y1 = (y0 > y1).long()

    sign = 1 - 2 * is_x0_gt_x1
    x0_comp, x1_comp, y0_comp, y1_comp = x0 * sign, x1 * sign, y0 * sign, y1 * sign
    points_low = get_points_from_low(x0_comp, y0_comp, x1_comp, y1_comp, num_points=num_points)
    points_low *= sign.view(-1, 1, 1).expand_as(points_low)

    sign = 1 - 2 * is_y0_gt_y1
    x0_comp, x1_comp, y0_comp, y1_comp = x0 * sign, x1 * sign, y0 * sign, y1 * sign
    points_high = get_points_from_high(x0_comp, y0_comp, x1_comp, y1_comp, num_points=num_points)
    points_high *= sign.view(-1, 1, 1).expand_as(points_high)

    is_dy_lt_dx = is_dy_lt_dx.view(-1, 1, 1).expand(-1, num_points, 2)
    points = points_low * is_dy_lt_dx + points_high * (1 - is_dy_lt_dx)

    return points

# Inputs:
# (@todo: extend A to cover all points in maps):
A = torch.LongTensor([[0, 1], [8, 6]])
C = torch.LongTensor([[6, 4], [2, 3]])
num_points = 3

# Getting points between A and C:
# (@todo: what if there's less than `num_points` between A-C?)
Bs = get_points_from(A[:, 0], A[:, 1], C[:, 0], C[:, 1], num_points=num_points)
print(Bs)
# tensor([[[1, 1],
#          [2, 2],
#          [3, 2]],
#         [[7, 6],
#          [6, 5],
#          [5, 5]]])
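The batched output above can be cross-checked against a plain-Python, one-pair version of the same stepping rule. This is a sketch (the name bresenham_from is hypothetical). Unlike get_points_from(), it stops once C is reached, which is one possible way to handle the `@todo` about having fewer than num_points pixels between A and C.

```python
def bresenham_from(a, c, num_points=3):
    """Up to num_points pixels that follow a on the Bresenham line from a to c.

    Scalar sketch of the same stepping rule as get_points_from(); unlike the
    batched version it stops once c is reached (c itself is included).
    """
    (x, y), (x1, y1) = a, c
    dx, dy = x1 - x, y1 - y
    sx, sy = (dx > 0) - (dx < 0), (dy > 0) - (dy < 0)   # step directions (-1, 0 or 1)
    adx, ady = abs(dx), abs(dy)
    shallow = ady < adx                                  # x is the major axis ("low" case)
    # Decision variable, initialised as in the "low"/"high" cases of the answer.
    D = 2 * ady - adx if shallow else 2 * adx - ady
    points = []
    for _ in range(num_points):
        if (x, y) == (x1, y1):
            break
        if shallow:
            x += sx
            if D > 0:
                y += sy
                D -= 2 * adx
            D += 2 * ady
        else:
            y += sy
            if D > 0:
                x += sx
                D -= 2 * ady
            D += 2 * adx
        points.append((x, y))
    return points

print(bresenham_from((0, 1), (6, 4)))  # [(1, 1), (2, 2), (3, 2)]
print(bresenham_from((8, 6), (2, 3)))  # [(7, 6), (6, 5), (5, 5)]
```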

Once you have your points, you can retrieve their "values" (Value(A), Value(B1), etc.) using torch.index_select() (note that this method only accepts a 1-D index tensor, so you need to unravel your data). Putting it all together, it could look something like the following (extending A from shape (Batch, 2) to (Batch, H, W, 2) is left as an exercise):

# Inputs:
# (@todo: extend A to cover all points in maps):
A = torch.LongTensor([[0, 1], [8, 6]])
C = torch.LongTensor([[6, 4], [2, 3]])
batch_size = A.shape[0]
num_points = 3
map_size = (9, 9)
map_num_elements = map_size[0] * map_size[1]
map_values = torch.stack((torch.arange(0, map_num_elements).view(*map_size),
                          torch.arange(0, -map_num_elements, -1).view(*map_size)))

# Getting points between A and C:
# (@todo: what if there's less than `num_points` between A-C?)
Bs = get_points_from(A[:, 0], A[:, 1], C[:, 0], C[:, 1], num_points=num_points)

# Get map values in positions A:
A_unravel = torch.arange(0, batch_size) * map_num_elements
A_unravel = A_unravel + A[:, 0] * map_size[1] + A[:, 1]
values_A = torch.index_select(map_values.view(-1), dim=0, index=A_unravel)
print(values_A)
# tensor([  1, -78])

# Get map values in positions B:
Bs_flatten = Bs.view(-1, 2)
Bs_unravel = (torch.arange(0, batch_size)
              .unsqueeze(1)
              .repeat(1, num_points)
              .view(num_points * batch_size) * map_num_elements)
Bs_unravel = Bs_unravel + Bs_flatten[:, 0] * map_size[1] + Bs_flatten[:, 1]
values_B = torch.index_select(map_values.view(-1), dim=0, index=Bs_unravel)
values_B = values_B.view(batch_size, num_points)
print(values_B)
# tensor([[ 10,  20,  29],
#         [-69, -59, -50]])

# Compute result:
res = torch.abs(values_A.unsqueeze(-1).expand_as(values_B) - values_B)
print(res)
# tensor([[ 9, 19, 28],
#         [ 9, 19, 28]])
res = torch.sum(res, dim=1)
print(res)
# tensor([56, 56])
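The same toy example can be reproduced in NumPy without unravelling by hand: np.ravel_multi_index computes the flat offsets, and advanced indexing with a batch index array gathers the values directly from the (Batch, H, W) maps. The B points are hard-coded from the output printed earlier so the snippet runs on its own.

```python
import numpy as np

# Same toy data as above: two 9 x 9 maps, the A points, and the B points printed earlier.
map_size = (9, 9)
n = map_size[0] * map_size[1]
map_values = np.stack((np.arange(0, n).reshape(map_size),
                       np.arange(0, -n, -1).reshape(map_size)))
A = np.array([[0, 1], [8, 6]])
Bs = np.array([[[1, 1], [2, 2], [3, 2]],
               [[7, 6], [6, 5], [5, 5]]])
batch = np.arange(A.shape[0])

# Flat offsets computed by NumPy (the same numbers as A_unravel above).
A_flat = np.ravel_multi_index((batch, A[:, 0], A[:, 1]), map_values.shape)

# Advanced indexing gathers values without flattening the maps.
values_A = map_values[batch, A[:, 0], A[:, 1]]                 # shape (Batch,)
values_B = map_values[batch[:, None], Bs[..., 0], Bs[..., 1]]  # shape (Batch, num_points)
res = np.abs(values_A[:, None] - values_B).sum(axis=1)
print(A_flat, values_A, res)
```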
