Reputation: 3288
I made a module that needs an extra loss term, e.g.
class MyModule:
    def forward(self, x):
        out = f(x)
        extra_loss = loss_f(self.parameters(), x)
        return out, extra_loss
I can't figure out how to make this module embeddable, for example, into a Sequential model: any regular module like Linear put after this one will fail because extra_loss causes the input to Linear to be a tuple, which Linear does not support.
So what I am looking for is a way to extract that extra loss after running the model forward:
my_module = MyModule()
model = Sequential(
    my_module,
    Linear(my_module_outputs, 1)
)
output = model(x)
my_module_loss = ????
loss = mse(label, output) + my_module_loss
Does module composability support this scenario?
Upvotes: 1
Views: 845
Reputation: 1690
IMHO, hooks are overkill here. Provided extra_loss is additive, we can accumulate it in a global (class-level) variable like this:
class MyModule:
    extra_loss = 0

    def forward(self, x):
        out = f(x)
        MyModule.extra_loss += loss_f(self.parameters(), x)
        return out

output = model(x)
loss = mse(label, output) + MyModule.extra_loss
MyModule.extra_loss = 0
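For reference, a minimal runnable sketch of the same idea. The Linear layer, the L2 weight penalty used as the extra loss, and the tensor shapes are assumptions made for illustration; they stand in for f and loss_f from the question.

```python
import torch
import torch.nn as nn


class MyModule(nn.Module):
    # Class-level accumulator shared by every MyModule instance.
    extra_loss = 0.0

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        out = self.linear(x)
        # Hypothetical extra loss: an L2 penalty on this module's weights.
        MyModule.extra_loss = MyModule.extra_loss + self.linear.weight.pow(2).sum()
        return out  # a plain tensor, so the next layer in Sequential accepts it


model = nn.Sequential(MyModule(5, 10), nn.Linear(10, 1))
x = torch.randn(4, 5)
label = torch.randn(4, 1)

MyModule.extra_loss = 0.0  # reset before each forward pass
output = model(x)
loss = nn.functional.mse_loss(output, label) + MyModule.extra_loss
loss.backward()
MyModule.extra_loss = 0.0  # drop the reference so the old graph can be freed
```

The accumulator has to be reset on every iteration; otherwise it keeps referencing the previous iteration's graph and the penalty is counted again. Because the state is shared across all instances, this only works as long as a single model runs at a time.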
Upvotes: 1
Reputation: 30936
You can register a hook in this case. A hook can be registered on a Tensor or an nn.Module. A hook is a function that is executed when either forward or backward is called. In this case, we want to attach a forward hook that stores the module's output without detaching it from the graph, so that the backward pass can still happen.
import torch
import torch.nn as nn

act_out = {}  # module name -> output captured by the forward hook

def get_hook(name):
    def hook(m, input, output):
        # Store the output as-is (not detached) so gradients can flow through it.
        act_out[name] = output
    return hook

class MyModule(torch.nn.Module):
    def __init__(self, input, out, device=None):
        super().__init__()
        self.model = nn.Linear(input, out)

    def forward(self, x):
        return self.model(x), torch.sum(x)  # our extra loss

class MyModule1(torch.nn.Module):
    def __init__(self, input, out, device=None):
        super().__init__()
        self.model = nn.Linear(input, out)

    def forward(self, pair):
        x, loss = pair  # unpack the tuple produced by MyModule
        return self.model(x)

model = nn.Sequential(
    MyModule(5, 10),
    MyModule1(10, 1)
)

for name, module in model.named_children():
    print(name, module)
    if name == '0':
        module.register_forward_hook(get_hook(name))

x = torch.tensor([1, 2, 3, 4, 5]).float()
out = model(x)
print(act_out)
# myanotherloss is a placeholder for your main loss function
loss = myanotherloss(out) + act_out['0'][1]  # this is the extra loss
# further processing
Note: I am using name == '0' because this is the only module where I want to attach the hook.
Note: Another notable point is that nn.Sequential doesn't allow multiple inputs. In this case, the output of the first module is simply passed along as a single tuple, and the next module unpacks the loss and the input from that tuple.
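As a variation on the tuple handling described in the note above, a forward hook may also return a value, in which case that value replaces the module's output. The sketch below relies on this to strip the extra loss out before the next layer sees it, so a plain nn.Linear can follow MyModule. The names make_split_hook and extra_losses, the mean-square extra loss, and the tensor shapes are assumptions made for illustration.

```python
import torch
import torch.nn as nn

extra_losses = {}  # hypothetical store: module name -> extra loss tensor


def make_split_hook(name):
    def hook(module, inputs, output):
        out, extra = output
        extra_losses[name] = extra  # kept attached to the autograd graph
        return out  # a non-None return value replaces the module's output
    return hook


class MyModule(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.model = nn.Linear(in_features, out_features)

    def forward(self, x):
        out = self.model(x)
        return out, out.pow(2).mean()  # illustrative extra loss


model = nn.Sequential(MyModule(5, 10), nn.Linear(10, 1))
model[0].register_forward_hook(make_split_hook("0"))

x = torch.randn(4, 5)
label = torch.randn(4, 1)
out = model(x)
loss = nn.functional.mse_loss(out, label) + extra_losses["0"]
loss.backward()
```

With this arrangement the second module no longer needs to know about the tuple, at the cost of the module behaving differently depending on whether the hook is registered.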
Upvotes: 1