user49404

Reputation: 917

PyTorch Design: Why does torch.distributions.multivariate_normal have methods outside of its class?

I'm trying to understand the design of PyTorch a little better. I wanted to draw samples from a multivariate normal and found torch.distributions.multivariate_normal, which, to my surprise, is a module with several protected (underscore-prefixed) functions defined outside of its MultivariateNormal class.

I was confused as to why this is the case. Why not define all of these functions as methods inside the MultivariateNormal class? That way, we could instantiate an object of this class with

torch.distributions.multivariate_normal(mu,sigma)

rather than

torch.distributions.multivariate_normal.MultivariateNormal(mu,sigma).

Any thoughts?

Thanks.

Upvotes: 0

Views: 1908

Answers (1)

aghriss

Reputation: 54

You can call MultivariateNormal directly:

import torch
gaussian = torch.distributions.MultivariateNormal(torch.ones(2),torch.eye(2))

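The class MultivariateNormal is implemented in the file "torch/distributions/multivariate_normal.py" and is re-exported by the torch.distributions package, so both torch.distributions.MultivariateNormal(...) and torch.distributions.multivariate_normal.MultivariateNormal(...) are correct and refer to the same class. Note that torch.distributions.multivariate_normal by itself is a module, so it cannot be called like a class.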
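To check this yourself, here is a minimal sketch, assuming a reasonably recent PyTorch install. The variable names (mu, sigma, same_class, module_is_callable) are only for illustration and are not part of the library:

```python
import torch
from torch.distributions import MultivariateNormal   # the class, re-exported by the package
from torch.distributions import multivariate_normal  # the module that defines the class

# Both import paths should resolve to the very same class object.
same_class = MultivariateNormal is multivariate_normal.MultivariateNormal

# Illustrative parameters: mean vector and covariance matrix of a 2-D Gaussian.
mu = torch.ones(2)
sigma = torch.eye(2)

gaussian = MultivariateNormal(mu, sigma)
samples = gaussian.sample((5,))  # draw 5 samples, each of dimension 2

# A module is not callable, which is why multivariate_normal(mu, sigma) fails.
module_is_callable = callable(multivariate_normal)

print(same_class, samples.shape, module_is_callable)
```

So the short spelling is available without changing how the module is laid out: the helper functions stay at module level in multivariate_normal.py, and the package simply exposes the class under torch.distributions.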

Upvotes: 2
