Lee Ching-Chan

Reputation: 35

Code Optimization: Computation in Torch.Tensor

I am currently implementing a function to compute a custom cross-entropy loss. The definition of the function is given in the following image.

Cited from "Deep Ordinal Regression Network for Monocular Depth Estimation", Huan Fu et al., CVPR 2018.
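(The image itself is not reproduced here; reconstructed from the code below, the loss appears to be

$$
\mathcal{L} \;=\; -\frac{1}{N}\sum_{(w,h)}\left(\sum_{k=0}^{\ell(w,h)-1}\log P_k^{(w,h)} \;+\; \sum_{k=\ell(w,h)}^{K-1}\log\bigl(1-P_k^{(w,h)}\bigr)\right),
\qquad
P_k^{(w,h)} \;=\; \sum_{i=k}^{K-1} o_i^{(w,h)},
$$

where $N$ is the total number of pixels, $K$ is the number of intervals, $\ell(w,h)$ is the ground-truth bin index (sid_t below), and $o_i^{(w,h)}$ are the channel values at pixel $(w,h)$.)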

My code is as follows:

output = output.permute(0, 2, 3, 1)
target = target.permute(0, 2, 3, 1)

batch, height, width, channel = output.size()

total_loss = 0.
for b in range(batch): # for each batch
    o = output[b]
    t = target[b]
    loss = 0.
    for w in range(width):
        for h in range(height): # for every pixel([h,w]) in the image
            sid_t = t[h][w][0]
            sid_o_candi = o[h][w]
            part1 = 0. # to store the first sigma 
            part2 = 0. # to store the second sigma

            for k in range(0, sid_t):
                p = torch.sum(sid_o_candi[k:]) # to get Pk(w,h)
                part1 += torch.log(p + 1e-12).item()

            for k in range(sid_t, intervals):
                p = torch.sum(sid_o_candi[k:]) # to get Pk(w,h)
                part2 += torch.log(1-p + 1e-12).item()

            loss += part1 + part2

    loss /= width * height * (-1)
    total_loss += loss
total_loss /= batch
return torch.tensor(total_loss, dtype=torch.float32)

I am wondering whether any optimization could be done to this code.

Upvotes: 0

Views: 58

Answers (1)

David Ng

Reputation: 1698

I'm not sure whether sid_t = t[h][w][0] is the same for every pixel. If it is, you can get rid of all the for loops, which would greatly speed up computing the loss.

Don't use .item(), because it returns a plain Python value and loses the grad_fn tracking; you then can't call loss.backward() to compute the gradients.
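A quick illustration of the difference (the tensor here is just made up for demonstration):

import torch

p = torch.tensor([0.2, 0.3, 0.5], requires_grad=True)

kept = torch.log(p.sum() + 1e-12)            # still a tensor; grad_fn is preserved
print(kept.grad_fn)                          # e.g. <LogBackward0 object at 0x...>

dropped = torch.log(p.sum() + 1e-12).item()  # plain Python float; the graph is lost
print(type(dropped))                         # <class 'float'>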

If sid_t = t[h][w][0] is not the same for every pixel, here is a modification that gets rid of at least one for-loop:



batch, height, width, channel = output.size()

total_loss = 0.
for b in range(batch): # for each batch
    o = output[b]
    t = target[b]
    loss = 0.
    for w in range(width):
        for h in range(height): # for every pixel([h,w]) in the image
            sid_t = t[h][w][0]
            sid_o_candi = o[h][w]
            # Pk(w,h) = sum of sid_o_candi[k:], computed for every k at once
            # via an inverse cumulative sum over the whole channel vector
            sid_cumsum = sid_o_candi.flip(dims=(0,)).cumsum(dim=0).flip(dims=(0,))

            part1 = torch.sum(torch.log(sid_cumsum[:sid_t] + 1e-12))
            part2 = torch.sum(torch.log(1 - sid_cumsum[sid_t:intervals] + 1e-12))

            loss += part1 + part2

    loss /= width * height * (-1)
    total_loss += loss
total_loss /= batch
return total_loss  # already a tensor; wrapping it in torch.tensor() would detach it from the graph

How it works:

import torch

x = torch.arange(10)
print(x)

x_flip = x.flip(dims=(0,))
print(x_flip)

x_inverse_cumsum = x_flip.cumsum(dim=0).flip(dims=(0,))
print(x_inverse_cumsum)

# output
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
tensor([45, 45, 44, 42, 39, 35, 30, 24, 17,  9])
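Going one step further: even when sid_t differs per pixel, you can drop every loop by building a comparison mask. This is only a sketch on my part, not tested against your data; it assumes output is (batch, channel, height, width) with channel == intervals, and target is (batch, 1, height, width) holding integer bin indices:

import torch

def ordinal_loss(output, target, intervals):
    output = output.permute(0, 2, 3, 1)        # (B, H, W, K)
    sid_t = target[:, 0, :, :]                 # (B, H, W), integer bins

    # inverse cumulative sum over channels: P[..., k] = sum of output[..., k:]
    P = output.flip(dims=(3,)).cumsum(dim=3).flip(dims=(3,))

    # mask[..., k] is True where k < sid_t, i.e. the "part1" terms
    k = torch.arange(intervals, device=output.device).view(1, 1, 1, -1)
    mask = k < sid_t.unsqueeze(-1)

    # log(P) where k < sid_t, log(1 - P) elsewhere, summed over k
    per_pixel = torch.where(mask,
                            torch.log(P + 1e-12),
                            torch.log(1 - P + 1e-12)).sum(dim=3)

    # negative mean over batch and pixels, matching the loop version
    return -per_pixel.mean()

Since everything stays inside tensor ops, autograd tracking survives and loss.backward() works as expected.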

Hope it helps.

Upvotes: 2
