Reputation: 434
I am building a cascade of neural networks, and I would like to backpropagate the main loss through all of the DNNs and also backpropagate an auxiliary loss to each DNN.
I am trying to figure out the best practice for building such a model and how to make sure that my losses are computed properly. Do I build a single torch.nn.Module and a single optimizer, or do I have to create separate modules and optimizers for each network? I am also likely to have more than three cascaded DNNs.
Approach a)
import torch
from torch import nn, optim

class MasterNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder sub-networks; substitute the real DNNs here.
        # (nn.ModuleList is only a container and cannot be called directly.)
        self.dnn1 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
        self.dnn2 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
        self.dnn3 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

    def forward(self, x, z1, z2):
        out1 = self.dnn1(x)
        out2 = self.dnn2(out1 + z1)
        out3 = self.dnn3(out2 + z2)
        return [out1, out2, out3]

# The loss bodies below are placeholders for the real computations.
def LossFunction(out):
    loss = out.pow(2).mean()  # do stuff
    return loss  # loss is a scalar value

def ac_loss_1_fn(out):
    loss = out.pow(2).mean()  # do stuff
    return loss  # loss is a scalar value

def ac_loss_2_fn(out):
    loss = out.pow(2).mean()  # do stuff
    return loss  # loss is a scalar value

def ac_loss_3_fn(out):
    loss = out.pow(2).mean()  # do stuff
    return loss  # loss is a scalar value

model = MasterNetwork()
optimizer = optim.Adam(model.parameters())

# Placeholder inputs; the shapes are arbitrary.
x = torch.randn(4, 8)
z1 = torch.randn(4, 8)
z2 = torch.randn(4, 8)

outputs = model(x, z1, z2)
main_loss = LossFunction(outputs[2])
ac1_loss = ac_loss_1_fn(outputs[0])
ac2_loss = ac_loss_2_fn(outputs[1])
ac3_loss = ac_loss_3_fn(outputs[2])

optimizer.zero_grad()
# This is where I am uncertain about how to backpropagate the AC losses
# for each DNN in addition to the main loss.
optimizer.step()
Approach b)
This would mean creating a separate nn.Module class and optimizer for each DNN and then forwarding the loss to the next DNN.
I would prefer a solution for approach a), since it is less tedious and I would not have to tune multiple optimizers. However, I am not sure whether this is possible. There was a similar question about backpropagating multiple losses, but I was not able to understand how combining the losses would work for the distinct components.
Upvotes: 0
Views: 2807
Reputation: 21
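The solution you are looking for is likely some form of the following. Note that the losses have to be combined with torch.stack rather than torch.tensor: torch.tensor copies the values into a new tensor that does not track autograd history, so calling backward on it would fail.
y = torch.stack([main_loss, ac1_loss, ac2_loss, ac3_loss])
y.backward(gradient=torch.ones_like(y))
With an all-ones gradient argument, this is equivalent to calling y.sum().backward().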
See https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients for confirmation.
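To make the vector-valued backward call concrete, here is a minimal sketch comparing it with simply summing the losses. The tensor `w` and the four toy losses are hypothetical stand-ins, not the losses from the question; the only point is that both routes should leave the same gradient on the parameters.

```python
import torch

# Hypothetical stand-ins for the four scalar losses; each depends on w.
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

def toy_losses(w):
    return [w.sum(), (w ** 2).sum(), (2 * w).sum(), w.prod()]

# Route 1: stack the losses and call backward with an all-ones gradient.
y = torch.stack(toy_losses(w))
y.backward(gradient=torch.ones_like(y))
grad_vector = w.grad.clone()

# Route 2: sum the same losses (recomputed on a fresh graph) and call backward.
w.grad = None
torch.stack(toy_losses(w)).sum().backward()
grad_sum = w.grad.clone()

print(torch.allclose(grad_vector, grad_sum))  # True
```

If the two routes agree, as they do here, there is no need for the `gradient=` argument unless the losses should be weighted differently, in which case the weights can be passed instead of ones.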
A similar question exists, but this one is phrased differently and was the first one I found when I hit the issue. The similar question is "Pytorch. Can autograd be used when the final tensor has more than a single value in it?"
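Applied to the cascade in the question (approach a), a single module and a single optimizer seem to be enough, because autograd only sends each loss back through the layers that produced it. The following is a rough sketch under assumptions: the `Cascade` class, the `nn.Linear` layer sizes, the squared-activation losses, and the weight `aux_weight` are all placeholders I made up for illustration.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

class Cascade(nn.Module):
    """Three cascaded sub-networks; the layer sizes are placeholders."""
    def __init__(self, dim=8):
        super().__init__()
        self.dnn1 = nn.Linear(dim, dim)
        self.dnn2 = nn.Linear(dim, dim)
        self.dnn3 = nn.Linear(dim, dim)

    def forward(self, x, z1, z2):
        out1 = self.dnn1(x)
        out2 = self.dnn2(out1 + z1)
        out3 = self.dnn3(out2 + z2)
        return out1, out2, out3

model = Cascade()
optimizer = optim.Adam(model.parameters())
x, z1, z2 = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
out1, out2, out3 = model(x, z1, z2)

# Placeholder losses: mean squared activation stands in for the real criteria.
main_loss = out3.pow(2).mean()
ac1_loss = out1.pow(2).mean()
ac2_loss = out2.pow(2).mean()
ac3_loss = out3.pow(2).mean()

# Check which parameters ac1_loss can reach at all (None means unreachable).
params = dict(model.named_parameters())
grads = torch.autograd.grad(
    ac1_loss, list(params.values()), retain_graph=True, allow_unused=True
)
reached_by_ac1 = [name for name, g in zip(params, grads) if g is not None]
print(reached_by_ac1)  # ['dnn1.weight', 'dnn1.bias']

# One backward pass over the weighted sum, then one optimizer step.
aux_weight = 0.3  # hypothetical weighting of the auxiliary losses
total_loss = main_loss + aux_weight * (ac1_loss + ac2_loss + ac3_loss)
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
```

Note that in this sketch `ac2_loss` also flows back into `dnn1`, since `out2` is computed from `out1`. If an auxiliary loss should train only its own sub-network, one option is to feed the next stage `out1.detach()`, but that also blocks the main loss from reaching `dnn1`, so it depends on what the cascade is meant to learn.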
Upvotes: 1