LOST

Reputation: 3288

How to return extra loss from module forward function in PyTorch?

I made a module, that needs an extra loss term, e.g.

class MyModule(Module):
    def forward(self, x):
        out = f(x)  # the module's regular computation
        extra_loss = loss_f(self.parameters(), x)  # auxiliary loss term
        return out, extra_loss

I can't figure out how to make this module embeddable into, for example, a Sequential model: any regular module such as Linear placed after it will fail, because extra_loss turns the input to Linear into a tuple, which Linear does not support.

So what I am looking for is a way to extract that extra loss after running the model's forward pass:

my_module = MyModule()

model = Sequential(
    my_module,
    Linear(my_module_outputs, 1)
)

output = model(x)
my_module_loss = ????
loss = mse(label, output) + my_module_loss

Does module composability support this scenario?

Upvotes: 1

Views: 845

Answers (2)

Alexey Birukov

Reputation: 1690

IMHO, hooks are overkill here. Provided extra_loss is additive, we can use a class-level (effectively global) variable like this:

class MyModule(nn.Module):
    extra_loss = 0  # class-level accumulator, shared across instances

    def forward(self, x):
        out = f(x)
        MyModule.extra_loss += loss_f(self.parameters(), x)
        return out

output = model(x)
loss = mse(label, output) + MyModule.extra_loss
MyModule.extra_loss = 0  # reset the accumulator for the next forward pass
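
For completeness, here is a self-contained sketch of this approach. A linear layer stands in for f and an L2 penalty on the module's parameters stands in for loss_f; both are made-up placeholders, not part of the original answer:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    extra_loss = 0  # class-level accumulator, shared across instances

    def __init__(self, n_in, n_out):
        super().__init__()
        self.linear = nn.Linear(n_in, n_out)  # stand-in for f

    def forward(self, x):
        out = self.linear(x)
        # stand-in for loss_f: an L2 penalty on this module's parameters
        MyModule.extra_loss = MyModule.extra_loss + sum(p.pow(2).sum() for p in self.parameters())
        return out

model = nn.Sequential(MyModule(5, 10), nn.Linear(10, 1))
x = torch.randn(8, 5)
label = torch.randn(8, 1)

output = model(x)
loss = nn.functional.mse_loss(output, label) + MyModule.extra_loss
loss.backward()
MyModule.extra_loss = 0  # reset before the next forward pass

One caveat of the class attribute: every instance of MyModule adds into the same accumulator, so it must be reset after each step, or the loss from one batch leaks into the next.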

Upvotes: 1

user2736738

Reputation: 30936

You can register a hook in this case. A hook can be registered on a Tensor or an nn.Module; it is a function that is executed when either forward or backward is called. Here we want to attach a forward hook that keeps the captured output attached to the graph, so that the backward pass can still happen.

import torch
import torch.nn as nn

act_out = {}
def get_hook(name):
    def hook(m, input, output):
        act_out[name] = output  # stores the full (tensor, extra_loss) output tuple
    return hook

class MyModule(torch.nn.Module):
    def __init__(self, input, out, device=None):
        super().__init__()
        self.model = nn.Linear(input, out)

    def forward(self, x):
        return self.model(x), torch.sum(x)  # our extra loss

class MyModule1(torch.nn.Module):
    def __init__(self, input, out, device=None):
        super().__init__()
        self.model = nn.Linear(input, out)

    def forward(self, pair):
        x, loss = pair  # unpack the tuple produced by MyModule
        return self.model(x)

model = nn.Sequential(
    MyModule(5, 10),
    MyModule1(10, 1)
)

for name, module in model.named_children():
    print(name, module)
    if name == '0':
        module.register_forward_hook(get_hook(name))

x = torch.tensor([1, 2, 3, 4, 5]).float()
out = model(x)

print(act_out)
loss = myanotherloss(out) + act_out['0'][1]  # the second tuple element is the extra loss
# further processing (myanotherloss is a placeholder for the main loss function)

Note: I am using name == '0' because '0' is the only module to which I want to attach the hook.

Note: Another notable point is that nn.Sequential doesn't allow multiple inputs between modules; the pair is simply passed along as a tuple, and the next module unpacks the input and the loss from that tuple.
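
To make the last two lines of the example concrete, here is a sketch of a full step with this setup; myanotherloss is swapped for a plain MSE against a made-up target, which is my assumption, not part of the answer:

target = torch.randn(1)
out = model(x)  # the forward pass fires the hook on module '0'
main_loss = nn.functional.mse_loss(out, target)
extra_loss = act_out['0'][1]  # torch.sum(x), captured by the hook
loss = main_loss + extra_loss
loss.backward()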

Upvotes: 1
