Reputation: 1005
I'm programming some callable custom modules in PyTorch and I wanted to know if I'm doing it correctly. Here's an example scenario where I want to construct a module that takes a torch.Tensor
as input, performs a learnable linear operation and outputs a diagonal covariance matrix to use in a multivariate distribution downstream.
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

class Exp(nn.Module):
    """Element-wise exponential: turns predicted log-variances into variances."""
    def forward(self, x):
        return x.exp()

class Diag(nn.Module):
    """Embeds a vector of variances into a diagonal covariance matrix."""
    def forward(self, x):
        return x.diag_embed()

def init_model(input_size, output_size):
    log_variance_module = nn.Linear(input_size, output_size)
    diag_covariance_module = nn.Sequential(log_variance_module, Exp(), Diag())
    return diag_covariance_module

model = init_model(5, 5)
cov = model(some_input_tensor)
dist = MultivariateNormal(some_mean, cov)
I know that this works, but is it the right design pattern? What is the recommended way to structure modules like these?
Upvotes: 0
Views: 167
Reputation: 40618
This looks like the correct design pattern.
Ideally, you would also write your main network as an nn.Module:
class Model(nn.Sequential):
    def __init__(self, input_size, output_size):
        logvar_module = nn.Linear(input_size, output_size)
        super().__init__(logvar_module, Exp(), Diag())
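For completeness, here is a minimal usage sketch of that Model class. It assumes the imports plus the Exp and Diag modules from the question are in scope; the input batch and the mean are made-up placeholders, not something from your code:

# Usage sketch, assuming Exp, Diag, torch, and MultivariateNormal from above.
model = Model(input_size=5, output_size=5)

x = torch.randn(3, 5)      # hypothetical batch of 3 input vectors
mean = torch.zeros(3, 5)   # hypothetical per-sample mean
cov = model(x)             # shape (3, 5, 5): batch of diagonal covariance matrices

dist = MultivariateNormal(mean, covariance_matrix=cov)
sample = dist.sample()     # shape (3, 5)

Subclassing nn.Sequential this way keeps the pipeline declarative while still giving you a named, reusable module whose constructor owns the layer sizes.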
Upvotes: 1