infiNity9819

Reputation: 588

PyTorch: list of all gradients in a model

I'm trying to clip the gradients in a simple deep network model (for RL). To do that, I first want to collect statistics of the gradients in each epoch, e.g. mean, max, etc. From these I can decide on the threshold value to clip my gradients to.

One approach would be to fetch all the computed gradients as an array after the backward pass (i.e. after calling loss.backward()).

How can I do this? Or is there any other way to determine this hyper-parameter?
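One common way to choose the threshold, sketched here as a heuristic assumption rather than an established rule, is to log the global L2 norm of all gradients for a while without clipping, then pick a high percentile of those norms as `max_norm` for `torch.nn.utils.clip_grad_norm_`. The model, data, step count, and the 90th-percentile choice are all hypothetical placeholders.

```python
import torch
import torch.nn as nn

def total_grad_norm(model: nn.Module) -> float:
    # Global L2 norm over all existing gradients: the same quantity that
    # torch.nn.utils.clip_grad_norm_ compares against max_norm.
    norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    return torch.norm(torch.stack(norms), 2).item()

torch.manual_seed(0)
# Hypothetical tiny model and random data, standing in for the real RL setup.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Phase 1: train without clipping and record the gradient norm at every step.
observed = []
for step in range(50):
    x, y = torch.randn(32, 4), torch.randn(32, 1)
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    observed.append(total_grad_norm(model))
    optimizer.step()

# Heuristic (assumption): use the 90th percentile of observed norms as threshold.
max_norm = torch.tensor(observed).quantile(0.9).item()

# Phase 2: clip with the chosen threshold. clip_grad_norm_ rescales the
# gradients in place and returns the total norm measured before clipping.
optimizer.zero_grad()
x, y = torch.randn(32, 4), torch.randn(32, 1)
nn.functional.mse_loss(model(x), y).backward()
pre_clip = float(nn.utils.clip_grad_norm_(model.parameters(), max_norm))
post_clip = total_grad_norm(model)
print(f"threshold={max_norm:.4f} before={pre_clip:.4f} after={post_clip:.4f}")
```

Because `clip_grad_norm_` returns the pre-clipping norm, you can keep logging that value during normal training and revisit the threshold if the distribution of norms shifts.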

Upvotes: 3

Views: 4360

Answers (1)

GoodDeeds

Reputation: 8527

You can iterate over the parameters to obtain their gradients. For example,

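for name, param in model.named_parameters():
    # param.grad is None for parameters that received no gradient
    if param.grad is not None:
        print(name, param.grad)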

The example above only prints each gradient tensor, but you can collect the tensors instead and compute whatever statistics you need, such as the mean, max, or norm.
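To get all gradients as a single array, a minimal sketch is to flatten each parameter's `.grad` and concatenate them with `torch.cat`, then compute the statistics on the result. The helper name `gradient_stats`, the toy model, and the random data are hypothetical, used only to produce some gradients.

```python
import torch
import torch.nn as nn

def gradient_stats(model: nn.Module) -> dict:
    # Flatten every existing gradient (grad is None for unused or frozen
    # parameters) and concatenate them into one 1-D tensor.
    grads = [p.grad.detach().flatten() for p in model.parameters() if p.grad is not None]
    if not grads:
        return {}
    all_grads = torch.cat(grads)
    return {
        "mean": all_grads.mean().item(),
        "max_abs": all_grads.abs().max().item(),
        "l2_norm": all_grads.norm(2).item(),
        "numel": all_grads.numel(),
    }

# Hypothetical tiny model and data, just to populate .grad.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(16, 4)
target = torch.randn(16, 2)

loss = nn.functional.mse_loss(model(x), target)
loss.backward()  # fills p.grad for each parameter

stats = gradient_stats(model)
print(stats)
```

If you need a NumPy array instead, `all_grads.cpu().numpy()` converts the concatenated tensor. Calling `gradient_stats` once per step and aggregating over an epoch gives the per-epoch statistics asked about in the question.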

Upvotes: 4
