Reputation: 306
Let's say I have a model called UNet:
output = UNet(input)
The output is a batch of grayscale images with shape (batch_size, 1, 128, 128).
I want to normalize each image to the range [0, 1].
I did it like this:
for i in range(batch_size):
    output[i,:,:,:] = output[i,:,:,:]/torch.amax(output,dim=(1,2,3))[i]
Now every image in the output is normalized, but when I train the model, PyTorch complains that it cannot compute the gradients through this procedure, and I understand why.
My question is: what is the right way to normalize the images without breaking the backpropagation flow? Something like:
output = UNet(input)
output = output.normalize
output2 = some_model(output)
loss = ..
loss.backward()
optimize.step()
My only option right now is adding a sigmoid activation at the end of the UNet, but I don't think that's a good idea.
Update 2, the code (gen2 and disc are the UNet and discriminator models; est_bias is the generator's output):
with torch.no_grad():
    est_bias_for_disc = gen2(input_img)
    est_bias_for_disc /= est_bias_for_disc.amax(dim=(1,2,3), keepdim=True)
disc_fake_hat = disc(est_bias_for_disc.detach())
disc_fake_loss = BCE(disc_fake_hat, torch.zeros_like(disc_fake_hat))
disc_real_hat = disc(bias_ref)
disc_real_loss = BCE(disc_real_hat, torch.ones_like(disc_real_hat))
disc_loss = (disc_fake_loss + disc_real_loss) / 2
if epoch<=epochs_till_gen2_stop:
    disc_loss.backward(retain_graph=True) # Update gradients
    opt_disc.step() # Update optimizer
Then there's a separate training step for the generator:
opt_gen2.zero_grad()
est_bias = gen2(input_img)
est_bias /= est_bias.amax(dim=(1,2,3), keepdim=True)
disc_fake = disc(est_bias)
ADV_loss = BCE(disc_fake, torch.ones_like(disc_fake))
gen2_loss = ADV_loss
gen2_loss.backward()
opt_gen2.step()
Upvotes: 3
Views: 10019
Reputation: 40648
Assigning into output[i,:,:,:] overwrites the tensor's values in place, and autograd needs the original values for the backward pass. Instead, you can perform the operation out-of-place in vectorized form:
output = output / output.amax(dim=(1,2,3), keepdim=True)
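The keepdim=True argument makes torch.Tensor.amax keep the reduced dimensions with size 1, so its output has shape (batch_size, 1, 1, 1) and broadcasts against the input. Because the division creates a new tensor instead of modifying output in place, autograd can still backpropagate through it.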
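As a minimal sketch of the vectorized form above: the model here is a single Conv2d layer followed by ReLU, used only as a stand-in for the UNet, and the helper name normalize_by_max and the eps guard are assumptions added for illustration, not part of the original answer.

```python
import torch
import torch.nn as nn


def normalize_by_max(images: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Divide each image of an (N, C, H, W) batch by its own maximum, out-of-place."""
    # keepdim=True yields shape (N, 1, 1, 1), which broadcasts over each image.
    per_image_max = images.amax(dim=(1, 2, 3), keepdim=True)
    # Assumption: clamp the divisor so an all-zero image does not divide by zero.
    return images / per_image_max.clamp(min=eps)


torch.manual_seed(0)
# Stand-in for the UNet (hypothetical); ReLU keeps the outputs non-negative.
model = nn.Sequential(nn.Conv2d(1, 1, kernel_size=3, padding=1), nn.ReLU())
batch = torch.rand(4, 1, 128, 128)

output = normalize_by_max(model(batch))  # new tensor, graph stays intact
loss = output.mean()
loss.backward()  # gradients reach the model parameters
```

Applied to the generator step in the question, this would correspond to replacing the in-place est_bias /= ... with an out-of-place assignment such as est_bias = est_bias / est_bias.amax(dim=(1,2,3), keepdim=True). Dividing by the maximum only maps into [0, 1] when the values are non-negative; if the network can output negative values, subtracting the per-image minimum first is one possible option.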
Upvotes: 2
Reputation: 81
You can use the normalize function:
>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.tensor([[3.,4.],[5.,6.],[7.,8.]])
>>> x = F.normalize(x, dim = 0)
>>> print(x)
tensor([[0.3293, 0.3714],
        [0.5488, 0.5571],
        [0.7683, 0.7428]])
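This gives a differentiable tensor as long as the out argument is not used. Note that F.normalize divides by the L2 norm along dim by default, so each column above has unit norm rather than a maximum of 1.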
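For the (batch_size, 1, 128, 128) case in the question, a rough sketch might look like the following. It assumes each image is flattened to one row so that F.normalize acts per image, and it uses a random tensor as a stand-in for the model output. Passing p=float("inf") divides by the largest absolute value in each row instead of the L2 norm, which is closer to the max-based scaling the question asks about.

```python
import torch
import torch.nn.functional as F

# Stand-in for a model output (assumption): a random batch that tracks gradients.
images = torch.rand(4, 1, 128, 128, requires_grad=True)

flat = images.flatten(start_dim=1)                     # (4, 128*128), one row per image
unit_l2 = F.normalize(flat, p=2.0, dim=1)              # each row has L2 norm 1
unit_max = F.normalize(flat, p=float("inf"), dim=1)    # each row has max |value| 1

normalized = unit_max.reshape(images.shape)            # back to (4, 1, 128, 128)
normalized.sum().backward()                            # gradients flow back to `images`
```

Both variants are out-of-place, so the original tensor is left untouched and the backward pass works.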
Upvotes: 4