A.E

Reputation: 1013

Copy-constructing from a tensor: UserWarning

I am creating a random tensor from a normal distribution. Since this tensor serves as a weight in the NN, it needs requires_grad set, so I wrap it with torch.tensor() as below:

import torch 

input_dim, hidden_dim = 3, 5

norm = torch.distributions.normal.Normal(loc=0, scale=0.01)
W = norm.sample((input_dim, hidden_dim))
W = torch.tensor(W, requires_grad=True)

I am getting the following UserWarning:

    UserWarning: To copy construct from a tensor,
    it is recommended to use sourceTensor.clone().detach() or
    sourceTensor.clone().detach().requires_grad_(True),
    rather than torch.tensor(sourceTensor).
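For reference, the pattern the warning itself recommends can be sketched as follows, assuming a recent PyTorch release. clone() copies the data, detach() cuts the copy off from any autograd graph, and requires_grad_(True) sets the flag in place and returns the same tensor. The warnings filter is only there to confirm that this path raises no warning.

```python
import warnings

import torch

input_dim, hidden_dim = 3, 5

norm = torch.distributions.normal.Normal(loc=0, scale=0.01)
W = norm.sample((input_dim, hidden_dim))

# Turn any Python warning into an error so we can confirm none is raised.
with warnings.catch_warnings():
    warnings.simplefilter("error")
    # clone() copies the data, detach() removes it from any autograd graph,
    # and requires_grad_(True) sets the flag in place and returns the tensor.
    W_param = W.clone().detach().requires_grad_(True)

print(W_param.shape, W_param.requires_grad, W_param.is_leaf)
```

The copy is a new leaf tensor, so the original W is left untouched (it still has requires_grad=False).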

Is there an alternative way to achieve this without triggering the warning? Thanks

Upvotes: 1

Views: 1948

Answers (1)

jodag

Reputation: 22244

You can simply set W.requires_grad to True on the sampled tensor itself; no copy is needed:

import torch 

input_dim, hidden_dim = 3, 5

norm = torch.distributions.normal.Normal(loc=0, scale=0.01)
W = norm.sample((input_dim, hidden_dim))
W.requires_grad = True
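This works because norm.sample(...) returns a fresh leaf tensor that is not part of any autograd graph, so its requires_grad flag can be changed in place; W.requires_grad_() is the equivalent method form. A minimal sketch to check that gradients actually reach W (the input x and the sum used as a loss are illustrative assumptions, not part of the original question):

```python
import torch

input_dim, hidden_dim = 3, 5

norm = torch.distributions.normal.Normal(loc=0, scale=0.01)
W = norm.sample((input_dim, hidden_dim))
W.requires_grad_()  # in-place; same effect as W.requires_grad = True

# Hypothetical input batch of 4 samples, used only to exercise autograd.
x = torch.randn(4, input_dim)
loss = (x @ W).sum()
loss.backward()

# d(sum(x @ W))/dW = x^T @ ones(4, hidden_dim): each column equals x.sum(0).
expected = x.t() @ torch.ones(4, hidden_dim)
print(W.grad.shape)  # torch.Size([3, 5])
print(torch.allclose(W.grad, expected))  # True
```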

Upvotes: 1
