Lugi

Reputation: 593

PyTorch - a functional equivalent of nn.Module

As we know, we can wrap an arbitrary number of stateful building blocks into a class that inherits from nn.Module. But how is this supposed to be done when you want to wrap a bunch of stateless functions (from torch.nn.functional), in order to fully use what nn.Module gives you, like automatic moving of tensors between CPU and GPU with just model.to(device)? A minimal sketch of what I mean is below.
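For concreteness, here is a made-up example (the MyModule class and the scale tensor are just illustrative names): a tensor created as a plain attribute does not follow model.to(device), so the forward pass breaks once the input lives on the GPU.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # A constant tensor used by the stateless ops below.
            # As a plain attribute it will NOT follow model.to(device)
            # and stays on the CPU.
            self.scale = torch.tensor(2.0)

        def forward(self, x):
            # Stateless functional op plus the plain-attribute tensor:
            # this raises a device-mismatch error when x is on the GPU.
            return F.relu(x) * self.scale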

Upvotes: 0

Views: 147

Answers (1)

Lugi

Reputation: 593

I found the solution: if an operation inside a module creates a new tensor, you have to register it with self.register_buffer so that it fully participates in the automatic moving between devices (and is also included in the module's state_dict). For example:
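A minimal sketch (MyModule and scale are illustrative names, not anything from the PyTorch API beyond register_buffer itself):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # register_buffer makes the tensor part of the module's
            # state: it moves with model.to(device) and is saved in
            # the state_dict, but it receives no gradients.
            self.register_buffer("scale", torch.tensor(2.0))

        def forward(self, x):
            return F.relu(x) * self.scale

    model = MyModule().to("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(4, device=next(model.buffers()).device)
    print(model(x))  # no device mismatch: the buffer moved with the module

If the tensor should instead be trainable, nn.Parameter is the right tool; register_buffer is for non-learnable state such as constants or running statistics.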

Upvotes: 1
