allen zhong

Reputation: 53

In PyTorch, in what situations does a loss function need to inherit from nn.Module?

I am confused about loss functions in PyTorch. Some people define a loss function as a plain Python function, while others define it as a class that inherits from nn.Module. In what situations do we need to define the loss function by inheriting from nn.Module? Many thanks.
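For reference, here is a minimal sketch of the two styles side by side, assuming a simple mean-squared-error loss. The names `mse_loss_fn` and `MSELossModule` are hypothetical. PyTorch itself ships both forms of this particular loss (`torch.nn.MSELoss` as a module and `torch.nn.functional.mse_loss` as a function), so the built-in functional version is included for comparison:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Style 1: a plain Python function (the name mse_loss_fn is hypothetical)
def mse_loss_fn(pred, target):
    return ((pred - target) ** 2).mean()

# Style 2: a class inheriting nn.Module that holds no parameters
class MSELossModule(nn.Module):
    def forward(self, pred, target):
        return ((pred - target) ** 2).mean()

pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 1.0, 1.0])

value_fn = mse_loss_fn(pred, target)
value_module = MSELossModule()(pred, target)
value_builtin = F.mse_loss(pred, target)  # PyTorch's own functional form

print(value_fn.item(), value_module.item(), value_builtin.item())
```

All three compute the same value. The module version has no parameters, so in this case the class wrapper only changes how the loss is called.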

Upvotes: 5

Views: 974

Answers (1)

Crystina

Reputation: 1230

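Generally, inheriting from nn.Module is only necessary when you want the module to hold trainable variables (nn.Parameter objects), so that they are registered with the module, returned by .parameters(), and saved in its state_dict. Otherwise, inheriting from it is optional.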
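As an illustration of the case where inheritance does matter, here is a sketch of a loss with a trainable variable. The class `LearnableWeightedLoss` is hypothetical and only loosely modeled on learned loss-weighting schemes. Because it inherits from nn.Module, the `nn.Parameter` it creates is registered automatically and shows up in `named_parameters()` and `state_dict()`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class LearnableWeightedLoss(nn.Module):
    """Hypothetical loss that combines two terms with a learnable weight."""

    def __init__(self):
        super().__init__()
        # Assigning an nn.Parameter as an attribute registers it with the module
        self.log_weight = nn.Parameter(torch.zeros(()))

    def forward(self, loss_a, loss_b):
        weight = torch.exp(self.log_weight)
        # The "- self.log_weight" term stops the optimizer from simply
        # driving the weight towards zero
        return loss_a + weight * loss_b - self.log_weight

model = nn.Linear(3, 1)
criterion = LearnableWeightedLoss()

# The loss's own parameter must be handed to the optimizer with the model's
optimizer = torch.optim.SGD(
    list(model.parameters()) + list(criterion.parameters()), lr=0.1
)

x = torch.randn(4, 3)
y = torch.randn(4, 1)
pred = model(x)
loss_a = ((pred - y) ** 2).mean()
loss_b = pred.abs().mean()

optimizer.zero_grad()
loss = criterion(loss_a, loss_b)
loss.backward()  # also computes a gradient for criterion.log_weight
optimizer.step()

print([name for name, _ in criterion.named_parameters()])
print(list(criterion.state_dict().keys()))
```

Note that registration alone does not train the weight: it is still up to you to pass `criterion.parameters()` to the optimizer, as done above.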

The same applies to loss functions: if a loss contains no such variables (which I assume is the most common case), no inheritance is needed. Autograd will backpropagate through a plain Python function just as well.
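A minimal sketch of the parameter-free case, assuming a small linear regression on random data: a plain function (the name `mse_loss_fn` is hypothetical) is used as the loss in an ordinary training loop, and gradients still reach the model's parameters:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A plain-function loss (hypothetical name); no nn.Module needed
def mse_loss_fn(pred, target):
    return ((pred - target) ** 2).mean()

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.randn(8, 3)
y = torch.randn(8, 1)

with torch.no_grad():
    initial_loss = mse_loss_fn(model(x), y).item()

for _ in range(20):
    optimizer.zero_grad()
    loss = mse_loss_fn(model(x), y)
    loss.backward()  # autograd traces through the plain function
    optimizer.step()

with torch.no_grad():
    final_loss = mse_loss_fn(model(x), y).item()

print(initial_loss, final_loss)
```

The loss decreases over the loop even though `mse_loss_fn` is not a module, because autograd tracks the tensor operations inside it, not the class it belongs to.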

Upvotes: 5
