I'm trying to understand the design of PyTorch a little better. I wanted to draw samples from a multivariate normal and found torch.distributions.multivariate_normal, which, to my surprise, is a module with many protected functions defined outside of its MultivariateNormal class.
I was confused as to why this is the case. Why not define all of these functions as methods inside the MultivariateNormal class? That way, we could instantiate an object of this class with
torch.distributions.multivariate_normal(mu,sigma)
rather than
torch.distributions.multivariate_normal.MultivariateNormal(mu,sigma).
Any thoughts?
Thanks.
You can access MultivariateNormal directly from torch.distributions:
import torch
gaussian = torch.distributions.MultivariateNormal(torch.ones(2), torch.eye(2))
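The class MultivariateNormal is defined in the file "torch/distributions/multivariate_normal.py", and torch/distributions/__init__.py re-exports it, so torch.distributions.MultivariateNormal and torch.distributions.multivariate_normal.MultivariateNormal refer to the same class and either spelling works. What does not work is torch.distributions.multivariate_normal(mu, sigma), because multivariate_normal is a module, and a Python module is not callable.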
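A minimal sketch to check this yourself, assuming a reasonably recent PyTorch install. It confirms that the two names are the same class object, that the submodule is not callable, and then draws samples. The mean mu and covariance sigma are arbitrary example values, not anything from the question. Functions with a leading underscore in that module are private by Python convention and are not meant to be called directly.

```python
import inspect

import torch
from torch.distributions import MultivariateNormal
from torch.distributions import multivariate_normal as mvn_module

# The package re-exports the class, so both names point to the same object.
assert MultivariateNormal is mvn_module.MultivariateNormal

# The submodule is a module object, not a class, so it cannot be called.
print(inspect.ismodule(mvn_module), callable(mvn_module))  # True False

# Arbitrary example parameters (assumed for illustration).
mu = torch.zeros(2)
sigma = torch.tensor([[1.0, 0.5], [0.5, 2.0]])  # symmetric positive definite
dist = MultivariateNormal(mu, covariance_matrix=sigma)

# Draw 1000 samples; the sample shape is prepended to the event shape (2,).
samples = dist.sample((1000,))
print(samples.shape)  # torch.Size([1000, 2])

# Log-density of the first three samples, one value per sample.
print(dist.log_prob(samples[:3]))
```

The private helpers living at module level rather than on the class does not change how you use the distribution: you only ever construct and call the public MultivariateNormal class.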