t-smart

Reputation: 187

How can I limit the range of parameters in pytorch?

Normally in PyTorch there is no strict limit on the values a model's parameters can take. What if I want them to stay in the range [0, 1]? Is there a way to prevent an update from moving the parameters outside that range?
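To make the question concrete, here is a minimal sketch showing that a plain optimizer step has no notion of bounds. The starting value 0.95, the SGD learning rate of 0.1 and the toy loss `-p.sum()` are assumptions chosen only for illustration: the loss has gradient -1, so one SGD step moves the parameter to about 1.05, outside [0, 1].

```python
import torch
import torch.nn as nn

# A single parameter that starts inside [0, 1] (toy value, assumed)
p = nn.Parameter(torch.tensor([0.95]))
optimizer = torch.optim.SGD([p], lr=0.1)

# Toy loss whose gradient w.r.t. p is -1, so SGD pushes p upward
loss = -p.sum()
loss.backward()
optimizer.step()  # p <- 0.95 - 0.1 * (-1), roughly 1.05

print(p.item())  # about 1.05, i.e. outside [0, 1]
```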

Upvotes: 6

Views: 4995

Answers (1)

Kevin Alex Zhang

Reputation: 311

A trick used in some generative adversarial networks (some of which require the discriminator's parameters to stay within a certain range, such as the weight clipping in the original Wasserstein GAN) is to clamp the values in place after every optimizer step. For example:

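```python
import torch

model = YourPyTorchModule()  # placeholder for your own nn.Module

for _ in range(epochs):
    optimizer.zero_grad()
    loss = ...  # compute your loss here
    loss.backward()
    optimizer.step()
    # Project every parameter back into [0, 1] after the update
    with torch.no_grad():
        for p in model.parameters():
            p.clamp_(0.0, 1.0)
```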
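A self-contained version of the same clamping loop that runs as-is is sketched below. The toy regression data, the `nn.Linear(3, 1)` model, Adam with `lr=0.05` and the 200 steps are all assumptions for illustration; one true weight (1.5) deliberately lies outside [0, 1], so the clamp actually has to intervene.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy regression data (assumed); the weight 1.5 is outside [0, 1] on purpose
x = torch.randn(64, 3)
y = x @ torch.tensor([[0.2], [0.7], [1.5]]) + 0.3

model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
loss_fn = nn.MSELoss()

for _ in range(200):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    # Projection step: clamp every parameter back into [0, 1] in place,
    # outside autograd so the clamp itself is not recorded in the graph
    with torch.no_grad():
        for p in model.parameters():
            p.clamp_(0.0, 1.0)

weights = model.weight.detach().flatten()
bias = model.bias.detach()
print(weights, bias)
```

Clamping is a projection applied after the optimizer has already moved the parameters, so optimizer state (for example Adam's moment estimates) is computed from the raw gradients and is not aware of the constraint. For enforcing a hard range, as in the GAN setting above, that is usually acceptable.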

Upvotes: 11
