klkh

In PyTorch, what is the difference between forward() and an ordinary method?

How is implementing the forward() method of a custom nn.Module class different from adding an ordinary method to that class?

I heard that the forward() method should only accept and return tensors, because PyTorch applies special processing to the inputs and outputs of forward(). But I tried passing non-tensor objects into forward() and returning them from it, and I also tried implementing a module that has no forward() method at all (instead, it has several custom-named methods that act like forward()). Both approaches worked well.
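A minimal sketch of the two experiments described above, assuming a recent PyTorch version. The class names (DictModule, TwoEntryPoints) and method names (encode, decode) are hypothetical, chosen only for illustration. Both modules run without error; the only thing that fails is calling the second module as ae(x), because nn.Module's default forward() raises NotImplementedError.

```python
import torch
import torch.nn as nn


class DictModule(nn.Module):
    """forward() that takes and returns non-tensor objects (a dict and a str)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, batch: dict, tag: str) -> dict:
        # Nothing in nn.Module forces these arguments or the return value to be tensors.
        return {"tag": tag, "out": self.linear(batch["x"])}


class TwoEntryPoints(nn.Module):
    """A module with no forward(); two custom-named methods do the work."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(4, 3)
        self.dec = nn.Linear(3, 4)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.enc(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.dec(z)


x = torch.randn(5, 4)

result = DictModule()({"x": x}, tag="demo")
print(result["tag"], result["out"].shape)  # demo torch.Size([5, 2])

ae = TwoEntryPoints()
print(ae.decode(ae.encode(x)).shape)  # torch.Size([5, 4])

# Calling the module itself needs a forward(); the inherited default raises.
try:
    ae(x)
    has_forward = True
except NotImplementedError:
    has_forward = False
    print("TwoEntryPoints has no forward(), so ae(x) raises NotImplementedError")
```

Note that encode() and decode() are ordinary Python methods: calling them does not go through nn.Module.__call__, so forward hooks registered on ae itself would never fire for them.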

Answers (1)

Wasi Ahmad

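The forward() method accepts parameters of any type; nothing restricts it to tensors. However, its purpose is to encapsulate the forward computation of the module. When you write model(x), nn.Module's __call__ method runs forward() for you and wraps it with extra processing, such as running any registered hooks. Likewise, inside forward() you should call nested submodules themselves (e.g. self.linear(x)) rather than their forward() methods, so that the same processing applies to them.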
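A minimal sketch of this call chain, assuming a recent PyTorch version; the Block class is hypothetical. model(x) goes through nn.Module.__call__, which runs Block.forward, which in turn calls the nested nn.Linear as a module, so a forward hook registered on that submodule fires.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # Call the submodule itself, not self.linear.forward(x),
        # so the submodule's own __call__ (and any hooks on it) runs too.
        return torch.relu(self.linear(x))


model = Block()
calls = []
# Forward hook signature: hook(module, inputs, output); returning None keeps the output.
model.linear.register_forward_hook(lambda mod, inp, out: calls.append("linear"))

x = torch.randn(2, 4)
y = model(x)  # nn.Module.__call__ -> Block.forward -> self.linear(x)
print(calls)  # ['linear']
```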

The recommended practice is:

Do NOT call the forward(x) method directly. Instead, call the model object itself, as in model(x), to perform a forward pass and get its predictions.

What happens if you do not do that?

If you call .forward() directly and hooks are registered on that module (for example, via register_forward_hook() or register_forward_pre_hook()), those hooks will not run.
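A small sketch illustrating this, assuming a recent PyTorch version; the hook function name log_hook is hypothetical. A forward hook registered on the model fires for model(x) but not for model.forward(x), even though both compute the same output.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 2), nn.ReLU())
fired = []


def log_hook(module, inputs, output):
    # Forward hook: runs after forward() returns, but only when invoked via __call__.
    fired.append(type(module).__name__)


model.register_forward_hook(log_hook)

x = torch.randn(1, 3)

out_call = model(x)            # goes through __call__: the hook fires
print(fired)                   # ['Sequential']

out_direct = model.forward(x)  # bypasses __call__: the hook does not fire
print(fired)                   # ['Sequential']

print(torch.equal(out_call, out_direct))  # True: same math, only the hook differs
```

Hooks registered on submodules can still fire during a direct forward() call, because nn.Sequential.forward calls each child as child(input); only hooks attached to the module whose forward() you call directly are skipped.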
