Penguin

How to call "backward" in a loop with 2 optimizers?

I have 2 networks that I'm trying to update:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import matplotlib.pyplot as plt
from tqdm import tqdm

softplus = torch.nn.Softplus()

class Model_RL(nn.Module):
    def __init__(self):
        super(Model_RL, self).__init__()
        self.fc1 = nn.Linear(3, 20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = softplus(self.fc3(x))
        return x

class Model_FA(nn.Module):
    def __init__(self):
        super(Model_FA, self).__init__()
        self.fc1 = nn.Linear(1, 20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = softplus(self.fc3(x))
        return x

net_RL = Model_RL()
net_FA = Model_FA()

The training loop is:

inps = torch.tensor([[1.0]])
y = torch.tensor(10.0)

opt_RL = optim.Adam(net_RL.parameters())
opt_FA = optim.Adam(net_FA.parameters()) 

baseline = 0
baseline_lr = 0.1

epochs = 100

for _ in tqdm(range(epochs)):

    for inp in inps:

        with torch.no_grad():
            net_FA(inp)
               
        for layer in range(3):
            out_RL = net_RL(torch.tensor([1.0,2.0,3.0]))
            mu, std = out_RL
            dist = Normal(mu, std)
            update_values = dist.sample() 
            log_p = dist.log_prob(update_values).mean()

            out = net_FA(inp) 
            reward = -torch.square((y - out)) 
            baseline = (1 - baseline_lr) * baseline + baseline_lr * reward

            loss_RL = - (reward - baseline) * log_p            
            opt_RL.zero_grad()
            opt_FA.zero_grad()
            loss_RL.backward()
            opt_RL.step()            

            out = net_FA(inp) 
            loss_FA = torch.mean(torch.square(y - out)) 
            opt_RL.zero_grad()
            opt_FA.zero_grad()
            loss_FA.backward()
            opt_FA.step()



print("Mean: " + str(mu.detach().numpy()) + ", Goal: " + str(y))
print("Standard deviation: " + str(std.detach().numpy()) + ", Goal: 0ish")  # std is already softplus-activated in Model_RL.forward

I'm getting 2 main errors:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward()...

When I add retain_graph=True to both backward calls, I get the following:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [30, 1]], which is output 0 of TBackward, is at version 5; expected version 4 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True)

My main question is: how can I make this training work?

Some intermediate questions:

Why is retain_graph=True needed here if I'm using a loop? I've read that "there is no need to use retain_graph=True. In each loop, a new graph is created".

Why does retain_graph=True seem to make training significantly slower (if I remove the other backward call)? This doesn't make sense to me, since a new computational graph should be created in each epoch (rather than a single graph being extended).

Answers (1)

Girish Hegde

I think the line baseline = (1 - baseline_lr) * baseline + baseline_lr * reward is causing the error, because:

  • The previous value of baseline is used to compute its new value.
  • PyTorch tracks all of these values inside one graph.
  • backward() frees the graph's saved intermediate results.
  • At step t + 1, baseline tries to backpropagate through baseline from step t.
  • But at step t + 1, the graph behind baseline from step t no longer exists.
  • This leads to the error.
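To make the bullets above concrete, here is a minimal sketch that does not use the question's networks. The names w, running and count_successful_steps are hypothetical: a scalar parameter w stands in for net_FA, and running stands in for baseline. out is rebuilt every iteration, so its graph is new, but running is carried over, so each new graph still points back into the previous iteration's graph, whose saved tensors the earlier backward() already freed.

```python
import torch

# Hypothetical stand-ins: `w` plays the role of a network parameter,
# `running` plays the role of the question's `baseline`.
w = torch.tensor(0.5, requires_grad=True)


def count_successful_steps(detach_running: bool, steps: int = 3) -> int:
    """Return how many backward() calls succeed before a RuntimeError."""
    running = torch.tensor(0.0)
    done = 0
    for _ in range(steps):
        out = torch.tanh(w * 2.0)  # fresh graph every iteration
        prev = running.detach() if detach_running else running
        running = 0.9 * prev + 0.1 * out  # value carried across iterations
        loss = (out - running) ** 2
        try:
            loss.backward()  # frees the saved tensors of the nodes it visits
        except RuntimeError as err:
            print("step", done, "failed:", str(err).splitlines()[0])
            break
        done += 1
    return done


print(count_successful_steps(detach_running=False))  # 1: the second backward() fails
print(count_successful_steps(detach_running=True))   # 3: each graph is independent
```

Detaching the carried value is the only difference between the two calls.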

Solution: since you are not optimizing baseline or anything behind it:

  • Initialize baseline as a torch tensor.
  • Detach it from the graph before updating its state.
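Detaching also explains why adding retain_graph=True did not help. A minimal sketch, assuming a hypothetical scalar parameter w in place of net_FA and torch.optim.SGD in place of Adam: with retain_graph=True nothing is freed, but opt.step() updates w in place and increments its version counter. If the running value is not detached, the next backward() reaches an older graph that saved w at its previous version, and autograd raises the "modified by an inplace operation" error from the question.

```python
import torch


def count_steps_with_optimizer(detach_running: bool, steps: int = 3) -> int:
    """Mimic the question's loop: retain_graph=True plus an optimizer step."""
    w = torch.nn.Parameter(torch.tensor(0.5))  # hypothetical stand-in for net_FA's weights
    opt = torch.optim.SGD([w], lr=0.1)
    running = torch.tensor(0.0)  # stand-in for `baseline`
    done = 0
    for _ in range(steps):
        out = torch.sin(w)  # sin's backward node saves `w` itself
        prev = running.detach() if detach_running else running
        running = 0.9 * prev + 0.1 * out
        loss = (out - running) ** 2
        opt.zero_grad()
        try:
            # Nothing is freed, but older graphs stay reachable through `running`.
            loss.backward(retain_graph=True)
        except RuntimeError as err:
            print("step", done, "failed:", str(err).splitlines()[0])
            break
        opt.step()  # in-place update of `w` bumps its version counter
        done += 1
    return done


print(count_steps_with_optimizer(detach_running=False))  # 1: version check fails
print(count_steps_with_optimizer(detach_running=True))   # 3
```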

Try this:

# initialize baseline as a torch tensor
baseline = torch.tensor(0.)
baseline_lr = 0.1

epochs = 100

for _ in tqdm(range(epochs)):

    for inp in inps:

        with torch.no_grad():
            net_FA(inp)
               
        for layer in range(3):
            out_RL = net_RL(torch.tensor([1.0,2.0,3.0]))
            mu, std = out_RL
            dist = Normal(mu, std)
            update_values = dist.sample() 
            log_p = dist.log_prob(update_values).mean()

            out = net_FA(inp) 
            reward = -torch.square((y - out)) 

            # detach baseline from graph
            baseline = (1 - baseline_lr) * baseline.detach() + baseline_lr * reward

            loss_RL = - (reward - baseline) * log_p            
            opt_RL.zero_grad()
            opt_FA.zero_grad()
            loss_RL.backward()
            opt_RL.step()            

            out = net_FA(inp) 
            loss_FA = torch.mean(torch.square(y - out)) 
            opt_RL.zero_grad()
            opt_FA.zero_grad()
            loss_FA.backward()
            opt_FA.step()

But actually, I don't understand why you are updating the networks 3 times for the same input.
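On the question's point about retain_graph=True slowing training: a plausible explanation is that, without the detach, every baseline is linked to all previous ones, so the graph reachable from the loss grows with each iteration and every backward() call walks the whole history. The sketch below (hypothetical names, with a scalar w standing in for a network parameter) counts the autograd nodes reachable from the loss at each step by following grad_fn.next_functions.

```python
import torch


def count_nodes(t: torch.Tensor) -> int:
    """Count the distinct autograd nodes reachable from t.grad_fn."""
    seen = set()
    stack = [t.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:  # None marks inputs that need no grad
            continue
        seen.add(node)
        stack.extend(next_node for next_node, _ in node.next_functions)
    return len(seen)


def graph_sizes(detach_running: bool, steps: int = 5):
    w = torch.tensor(0.5, requires_grad=True)  # hypothetical stand-in for a parameter
    running = torch.tensor(0.0)                 # stand-in for `baseline`
    sizes = []
    for _ in range(steps):
        out = torch.tanh(w * 2.0)
        prev = running.detach() if detach_running else running
        running = 0.9 * prev + 0.1 * out
        loss = (out - running) ** 2
        sizes.append(count_nodes(loss))
        loss.backward(retain_graph=True)  # keeps every earlier graph alive
    return sizes


print(graph_sizes(detach_running=False))  # grows every iteration
print(graph_sizes(detach_running=True))   # stays constant
```

With the detach, each backward() only touches the current iteration's nodes, so its cost stays flat.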
