Reputation: 35
I am currently implementing a function to compute a custom cross-entropy loss. The definition of the function is given in the following image.
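(The image itself is not reproduced here. Reading the definition off the code below, the loss appears to be

$$
\mathcal{L} \;=\; -\frac{1}{N}\sum_{b=1}^{N}\frac{1}{WH}\sum_{w=1}^{W}\sum_{h=1}^{H}\left(\sum_{k=0}^{\operatorname{sid}(w,h)-1}\log P_k(w,h) \;+\; \sum_{k=\operatorname{sid}(w,h)}^{K-1}\log\bigl(1-P_k(w,h)\bigr)\right),
\qquad P_k(w,h)=\sum_{j\ge k} o_j(w,h),
$$

where $K$ is the number of intervals, $o_j(w,h)$ is the network output for channel $j$ at pixel $(w,h)$, and $\operatorname{sid}(w,h)$ is the target bin index.)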
My code is as follows:
output = output.permute(0, 2, 3, 1)
target = target.permute(0, 2, 3, 1)
batch, height, width, channel = output.size()
total_loss = 0.
for b in range(batch):  # for each batch
    o = output[b]
    t = target[b]
    loss = 0.
    for w in range(width):
        for h in range(height):  # for every pixel ([h, w]) in the image
            sid_t = t[h][w][0]
            sid_o_candi = o[h][w]
            part1 = 0.  # to store the first sigma
            part2 = 0.  # to store the second sigma
            for k in range(0, sid_t):
                p = torch.sum(sid_o_candi[k:])  # to get P_k(w, h)
                part1 += torch.log(p + 1e-12).item()
            for k in range(sid_t, intervals):
                p = torch.sum(sid_o_candi[k:])  # to get P_k(w, h)
                part2 += torch.log(1 - p + 1e-12).item()
            loss += part1 + part2
    loss /= width * height * (-1)
    total_loss += loss
total_loss /= batch
return torch.tensor(total_loss, dtype=torch.float32)
I am wondering whether there is any optimization that could be done to this code.
Upvotes: 0
Views: 58
Reputation: 1698
I'm not sure whether sid_t = t[h][w][0] is the same for every pixel or not. If it is, you can get rid of all the for loops, which would greatly speed up the loss computation.
Don't use .item(), because it returns a plain Python value and loses the grad_fn tracking; you then can't call loss.backward() to compute the gradients.
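Here is a minimal sketch of the difference, with a made-up tensor just for illustration:

import torch

p = torch.tensor([0.2, 0.3, 0.5], requires_grad=True)

# keeping the result as a tensor preserves the autograd graph
loss_tensor = torch.log(p.sum() + 1e-12)
print(loss_tensor.requires_grad)  # True  -> loss_tensor.backward() works

# .item() converts to a plain Python float, cutting the graph
loss_float = torch.log(p.sum() + 1e-12).item()
print(type(loss_float))           # <class 'float'> -> no backward() possible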
If sid_t = t[h][w][0] is not the same, here is a modification that gets rid of at least one for-loop:
batch, height, width, channel = output.size()
total_loss = 0.
for b in range(batch):  # for each batch
    o = output[b]
    t = target[b]
    loss = 0.
    for w in range(width):
        for h in range(height):  # for every pixel ([h, w]) in the image
            sid_t = t[h][w][0]
            sid_o_candi = o[h][w]

            # reverse cumulative sum over the whole channel vector:
            # entry k holds sum(sid_o_candi[k:]), i.e. P_k(w, h)
            sid1_cumsum = sid_o_candi.flip(dims=(0,)).cumsum(dim=0).flip(dims=(0,))[:sid_t]
            part1 = torch.sum(torch.log(sid1_cumsum + 1e-12))

            sid2_cumsum = sid_o_candi[sid_t:intervals].flip(dims=(0,)).cumsum(dim=0).flip(dims=(0,))
            part2 = torch.sum(torch.log(1 - sid2_cumsum + 1e-12))

            loss += part1 + part2
    loss /= width * height * (-1)
    total_loss += loss
total_loss /= batch
return total_loss  # keep it as a tensor so that loss.backward() still works
How it works:
x = torch.arange(10)
print(x)
x_flip = x.flip(dims=(0,))
print(x_flip)
x_inverse_cumsum = x_flip.cumsum(dim=0).flip(dims=(0,))
print(x_inverse_cumsum)
# output
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
tensor([45, 45, 44, 42, 39, 35, 30, 24, 17, 9])
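And if sid_t does turn out to be the same for every pixel, or if you turn the interval index into a mask, you can drop every loop. Below is a rough, untested sketch of a fully vectorized version; it assumes output has shape (N, C, H, W) with C == intervals, that target stores the integer bin index in channel 0, and the function name custom_ce_loss is just a placeholder:

import torch

def custom_ce_loss(output, target, intervals):
    # output: (N, C, H, W), target: (N, 1, H, W) with integer bin indices
    output = output.permute(0, 2, 3, 1)                       # (N, H, W, C)
    sid_t = target.permute(0, 2, 3, 1)[..., 0].long()         # (N, H, W)

    # reverse cumulative sum over channels: P[..., k] = sum_{j >= k} output[..., j]
    P = output.flip(dims=(3,)).cumsum(dim=3).flip(dims=(3,))  # (N, H, W, C)

    # mask is True where k < sid_t (first sum), False where k >= sid_t (second sum)
    k = torch.arange(intervals, device=output.device).view(1, 1, 1, -1)
    mask = k < sid_t.unsqueeze(-1)                            # (N, H, W, C)

    log_p = torch.log(P + 1e-12)
    log_1mp = torch.log(1 - P + 1e-12)
    per_pixel = torch.where(mask, log_p, log_1mp).sum(dim=3)  # (N, H, W)

    # negate and average over pixels and batch (same normalisation as the loops above)
    return -per_pixel.mean()

This keeps everything as tensor operations, so grad_fn is preserved and loss.backward() works.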
Hope it helps.
Upvotes: 2