Reputation: 465
I would like to implement the training loop below in pytorch-lightning (to be read as pseudo-code). The peculiarity is that the backward and optimization steps are not performed for every batch.
(Background: I am trying to implement a few-shot learning algorithm; although I need to make predictions at every step -- the forward method -- I only need to perform the gradient updates at random -- the if block.)
for batch in batches:
    x, y = batch
    loss = forward(x, y)
    optimizer.zero_grad()
    if np.random.rand() > 0.5:
        loss.backward()
        optimizer.step()
My proposed solution entails implementing the backward and the optimizer_step methods as follows:
def backward(self, use_amp, loss, optimizer):
    self.compute_grads = False
    if np.random.rand() > 0.5:
        loss.backward()
        nn.utils.clip_grad_value_(self.enc.parameters(), 1)
        nn.utils.clip_grad_value_(self.dec.parameters(), 1)
        self.compute_grads = True
    return
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
    if self.compute_grads:
        optimizer.step()
        optimizer.zero_grad()
    return
Note: in this way I need to store a compute_grads attribute at the class level.
What is the "best-practice" way to implement it in pytorch-lightning? Is there a better way to use the hooks?
Upvotes: 0
Views: 4034
Reputation: 173
This is a good way to do it! That's what the hooks are for.
There is a new Callbacks module that might also be helpful: https://pytorch-lightning.readthedocs.io/en/0.7.1/callbacks.html
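For completeness, here is a minimal sketch of how the two overrides from the question could sit inside a LightningModule, reusing the hook signatures as written there (so assuming the same Lightning version). The module name, the dummy Linear encoder/decoder, the training_step body, and the Adam optimizer are placeholders for illustration only; substitute your own model.

import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl


class FewShotModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # placeholder encoder/decoder; replace with your own modules
        self.enc = nn.Linear(10, 10)
        self.dec = nn.Linear(10, 10)
        self.compute_grads = False

    def training_step(self, batch, batch_nb):
        # dummy loss just to make the example self-contained
        x, y = batch
        y_hat = self.dec(self.enc(x))
        loss = nn.functional.mse_loss(y_hat, y)
        return {'loss': loss}

    def backward(self, use_amp, loss, optimizer):
        # backpropagate (and later step) only on ~half of the batches
        self.compute_grads = False
        if np.random.rand() > 0.5:
            loss.backward()
            nn.utils.clip_grad_value_(self.enc.parameters(), 1)
            nn.utils.clip_grad_value_(self.dec.parameters(), 1)
            self.compute_grads = True

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                       second_order_closure=None):
        # step only when backward() actually produced gradients
        if self.compute_grads:
            optimizer.step()
            optimizer.zero_grad()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

As far as I know the random-skip logic itself still belongs in the LightningModule hooks as above; the Callbacks linked above are better suited to orthogonal concerns such as logging or early stopping.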
Upvotes: 1