Adar Cohen

Reputation: 306

How to normalize pytorch model output to be in range [0,1]

Let's say I have a model called UNet:

output = UNet(input)

The output is a batch of grayscale images with shape (batch_size, 1, 128, 128).

What I want to do is normalize each image to be in the range [0, 1].

I did it like this:

for i in range(batch_size):
    output[i, :, :, :] = output[i, :, :, :] / torch.amax(output, dim=(1, 2, 3))[i]

Now every image in the output is normalized, but when I train such a model, PyTorch claims it cannot calculate the gradients in this procedure, and I understand why.

My question is: what is the right way to normalize an image without killing the backpropagation flow? Something like:

output = UNet(input)
output = normalize(output)  # <- the missing piece
output2 = some_model(output)
loss = ...
loss.backward()
optimizer.step()

My only option right now is adding a sigmoid activation at the end of the UNet, but I don't think it's a good idea.
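For illustration, the sigmoid fallback would be something like this (just a sketch; UNetWithSigmoid is a hypothetical wrapper name):

import torch

class UNetWithSigmoid(torch.nn.Module):
    # hypothetical wrapper that squashes the UNet output into (0, 1)
    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, x):
        # sigmoid is differentiable, so backprop works, but it distorts
        # the value distribution instead of linearly rescaling it
        return torch.sigmoid(self.unet(x))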

Update - code (gen2, disc = UNet and discriminator models; est_bias is the generator's output):

Update 2 - code:

with torch.no_grad():
    est_bias_for_disc = gen2(input_img)
    est_bias_for_disc /= est_bias_for_disc.amax(dim=(1, 2, 3), keepdim=True)
disc_fake_hat = disc(est_bias_for_disc.detach())
disc_fake_loss = BCE(disc_fake_hat, torch.zeros_like(disc_fake_hat))

disc_real_hat = disc(bias_ref)
disc_real_loss = BCE(disc_real_hat, torch.ones_like(disc_real_hat))
disc_loss = (disc_fake_loss + disc_real_loss) / 2
if epoch <= epochs_till_gen2_stop:
    disc_loss.backward(retain_graph=True)  # update gradients
    opt_disc.step()  # update optimizer

Then there's separate training:

opt_gen2.zero_grad()
est_bias = gen2(input_img)
est_bias = est_bias / est_bias.amax(dim=(1, 2, 3), keepdim=True)  # out of place, so autograd keeps the original values
disc_fake = disc(est_bias)
ADV_loss = BCE(disc_fake, torch.ones_like(disc_fake))
gen2_loss = ADV_loss
gen2_loss.backward()
opt_gen2.step()

Upvotes: 3

Views: 10019

Answers (2)

Ivan

Reputation: 40648

You are overwriting the tensor's values in place because of the indexing on the batch dimension, which breaks autograd. Instead, you can perform the operation out of place, in vectorized form:

output = output / output.amax(dim=(1,2,3), keepdim=True)

The keepdim=True argument keeps the number of dimensions of torch.Tensor.amax's output the same as its input's (here the maxima have shape (batch_size, 1, 1, 1)), so the result broadcasts correctly against output in the division.
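A quick self-contained check (a sketch; it assumes the network output is non-negative, so dividing by the per-image maximum lands the values in [0, 1]):

import torch

output = torch.rand(4, 1, 128, 128, requires_grad=True)  # stand-in for UNet(input)
normalized = output / output.amax(dim=(1, 2, 3), keepdim=True)
normalized.sum().backward()  # gradients flow through the division
print(output.grad.shape)  # torch.Size([4, 1, 128, 128])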

Upvotes: 2

arizonatea

Reputation: 81

You can use the normalize function:

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.tensor([[3., 4.], [5., 6.], [7., 8.]])
>>> x = F.normalize(x, dim=0)
>>> print(x)
tensor([[0.3293, 0.3714],
        [0.5488, 0.5571],
        [0.7683, 0.7428]])

This will give a differentiable tensor as long as out is not used.
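Note that F.normalize performs Lp normalization (p=2 by default) along dim, i.e. it divides by the vector norm rather than rescaling into [0, 1]; each column above ends up with unit L2 norm, which you can verify:

>>> x.norm(dim=0)
tensor([1.0000, 1.0000])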

Upvotes: 4
