Aryan Philip

Reputation: 50

I am running into a gradient computation inplace error

I am running this code (https://github.com/ayu-22/BPPNet-Back-Projected-Pyramid-Network/blob/master/Single_Image_Dehazing.ipynb) on a custom dataset, but I am running into this error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!


Please refer to the code link above for clarification of where the error is occurring.

I am running this model on a custom dataset; the data loader part is pasted below.

    import os
    import torch
    from PIL import Image
    from torch.utils.data import Dataset
    import torchvision.transforms as transforms

    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        # transforms.RandomResizedCrop(256),
        # transforms.RandomHorizontalFlip(),
        # transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    class Flare(Dataset):
        def __init__(self, flare_dir, wf_dir, transform=None):
            self.flare_dir = flare_dir
            self.wf_dir = wf_dir
            self.transform = transform
            self.flare_img = os.listdir(flare_dir)
            self.wf_img = os.listdir(wf_dir)

        def __len__(self):
            return len(self.flare_img)

        def __getitem__(self, idx):
            f_img = Image.open(os.path.join(self.flare_dir, self.flare_img[idx])).convert("RGB")
            # Pair each flare image with its flare-free counterpart by filename
            for i in self.wf_img:
                if self.flare_img[idx].split('.')[0][4:] == i.split('.')[0]:
                    wf_img = Image.open(os.path.join(self.wf_dir, i)).convert("RGB")
                    break
            f_img = self.transform(f_img)
            wf_img = self.transform(wf_img)

            return f_img, wf_img

    flare_dir = '../input/flaredataset/Flare/Flare_img'
    wf_dir = '../input/flaredataset/Flare/Without_Flare_'
    flare_img = os.listdir(flare_dir)
    wf_img = os.listdir(wf_dir)
    wf_img.sort()
    flare_img.sort()
    print(wf_img[0])
    train_ds = Flare(flare_dir, wf_dir, train_transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_ds,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

To get a better idea of the dataset class, you can compare it with the one in the notebook linked above.

Upvotes: 0

Views: 804

Answers (1)

Satya Prakash Dash

Reputation: 1216

Your code is failing in the backpropagation step of your GAN network.

The backward pass you have currently defined is:

    def backward(self, unet_loss, dis_loss):
        dis_loss.backward(retain_graph=True)
        self.dis_optimizer.step()

        unet_loss.backward()
        self.unet_optimizer.step()

In this backward pass you first propagate dis_loss, which combines the discriminator and adversarial losses, and then propagate unet_loss, which combines the UNet, SSIM, and content losses. But unet_loss is connected to the discriminator's output, so calling self.dis_optimizer.step() modifies the discriminator's weights in place before unet_loss.backward() has run. PyTorch then raises this error because the tensors it saved for the second backward pass are no longer at the expected version. I would recommend changing the code as follows:

    def backward(self, unet_loss, dis_loss):
        dis_loss.backward(retain_graph=True)
        unet_loss.backward()

        self.dis_optimizer.step()
        self.unet_optimizer.step()

This will let your training start. You can still experiment with retain_graph=True.
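Here is a minimal, self-contained sketch of the failure mode. The toy modules and losses below are illustrative stand-ins, not the notebook's actual networks: stepping one optimizer between the two backward calls bumps the in-place version counter of the "discriminator" weights that the second backward still needs, which reproduces the error; running both backward passes before either step succeeds.

```python
import torch

torch.manual_seed(0)
gen = torch.nn.Linear(4, 4)   # stand-in "generator"/UNet
dis = torch.nn.Linear(4, 1)   # stand-in "discriminator"
gen_opt = torch.optim.SGD(gen.parameters(), lr=0.1)
dis_opt = torch.optim.SGD(dis.parameters(), lr=0.1)

def losses():
    x = torch.randn(2, 4)
    dis_out = dis(gen(x))
    dis_loss = dis_out.mean()          # "discriminator" loss
    unet_loss = (dis_out ** 2).mean()  # "generator" loss; also flows through dis
    return unet_loss, dis_loss

# Wrong order: dis_opt.step() updates dis.weight in place, but
# unet_loss.backward() still needs dis.weight at its old version.
unet_loss, dis_loss = losses()
dis_loss.backward(retain_graph=True)
dis_opt.step()
failed = False
try:
    unet_loss.backward()
except RuntimeError as e:
    failed = True
    print("Wrong order raises:", type(e).__name__)

gen_opt.zero_grad()
dis_opt.zero_grad()

# Right order: run both backward passes first, then step both optimizers.
unet_loss, dis_loss = losses()
dis_loss.backward(retain_graph=True)
unet_loss.backward()
dis_opt.step()
gen_opt.step()
ok = True
print("Right order: ok")
```

The key point is that an optimizer step is itself an in-place operation on the parameters, so any backward pass that still depends on those parameters must happen before the step.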

And great work on BPPNet.

Upvotes: 1
