Pietro

Reputation: 465

Custom backward/optimization steps in pytorch-lightning

I would like to implement the training loop below in pytorch-lightning (to be read as pseudo-code). The peculiarity is that the backward and optimization steps are not performed for every batch.

(Background: I am trying to implement a few-shot learning algorithm; although I need to make predictions at every step -- the forward method -- I only need to perform the gradient updates at random -- the if-block.)

for batch in batches:
    x, y = batch
    loss = forward(x,y)

    optimizer.zero_grad()

    if np.random.rand() > 0.5:
        loss.backward()
        optimizer.step()

My proposed solution entails overriding the backward and optimizer_step hooks as follows:

def backward(self, use_amp, loss, optimizer):
    # randomly decide whether this batch contributes a gradient update
    self.compute_grads = False
    if np.random.rand() > 0.5:
        loss.backward()
        nn.utils.clip_grad_value_(self.enc.parameters(), 1)
        nn.utils.clip_grad_value_(self.dec.parameters(), 1)
        self.compute_grads = True

def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
    # only step when backward() actually produced gradients
    if self.compute_grads:
        optimizer.step()
        optimizer.zero_grad()

Note: with this approach I need to store a compute_grads attribute on the LightningModule so that optimizer_step knows whether backward actually ran.
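For reference, here is a minimal sketch of how these overrides could sit inside a LightningModule. The encoder/decoder layers, the loss in training_step and configure_optimizers are placeholders of my own, and the hook signatures follow the older pytorch-lightning API used above:

import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl

class RandomUpdateModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # placeholder encoder/decoder; the real few-shot model goes here
        self.enc = nn.Linear(32, 16)
        self.dec = nn.Linear(16, 1)
        self.compute_grads = False

    def forward(self, x):
        return self.dec(self.enc(x))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def backward(self, use_amp, loss, optimizer):
        # randomly decide whether this batch contributes a gradient update
        self.compute_grads = False
        if np.random.rand() > 0.5:
            loss.backward()
            nn.utils.clip_grad_value_(self.enc.parameters(), 1)
            nn.utils.clip_grad_value_(self.dec.parameters(), 1)
            self.compute_grads = True

    def optimizer_step(self, current_epoch, batch_nb, optimizer,
                       optimizer_i, second_order_closure=None):
        # only step when backward() actually produced gradients
        if self.compute_grads:
            optimizer.step()
            optimizer.zero_grad()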

What is the "best-practice" way to implement it in pytorch-lightning? Is there a better way to use the hooks?

Upvotes: 0

Views: 4034

Answers (1)

xela

Reputation: 173

This is a good way to do it! That's what the hooks are for.

There is a new Callbacks module that might also be helpful: https://pytorch-lightning.readthedocs.io/en/0.7.1/callbacks.html
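For example, a callback could make the per-batch decision while your overridden backward/optimizer_step only read the flag. A rough sketch, assuming a Lightning version whose Callback class exposes on_batch_start(trainer, pl_module) (the exact hook names differ between releases):

import numpy as np
from pytorch_lightning.callbacks import Callback

class RandomUpdateCallback(Callback):
    # decides once per batch whether the gradient update should happen
    def __init__(self, update_prob=0.5):
        self.update_prob = update_prob

    def on_batch_start(self, trainer, pl_module):
        # backward()/optimizer_step() in the LightningModule would check this
        # flag instead of drawing the random number themselves
        pl_module.compute_grads = np.random.rand() < self.update_prob

# passed to the trainer alongside the model, e.g.
# trainer = pl.Trainer(callbacks=[RandomUpdateCallback(update_prob=0.5)])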

Upvotes: 1
