isle_of_gods

Reputation: 1005

PyTorch Tensor methods to nn.Modules?

I'm writing some custom callable modules in PyTorch and I want to know whether I'm doing it correctly. Here's an example scenario: I want to build a module that takes a torch.Tensor as input, applies a learnable linear operation, and outputs a diagonal covariance matrix to use in a multivariate distribution downstream.

import torch
from torch import nn
from torch.distributions import MultivariateNormal

class Exp(nn.Module):
    # element-wise exponential, so predicted log-variances become positive variances
    def forward(self, x):
        return x.exp()

class Diag(nn.Module):
    # embed the last dimension as the diagonal of a square matrix
    def forward(self, x):
        return x.diag_embed()

def init_model(input_size, output_size):
    logvar_module = nn.Linear(input_size, output_size)
    diag_covariance_module = nn.Sequential(logvar_module, Exp(), Diag())
    return diag_covariance_module

model = init_model(5, 5)
cov = model(some_input_tensor)             # shape (..., 5, 5), diagonal covariance
dist = MultivariateNormal(some_mean, cov)

I know that this works, but is it the right design pattern? How is one recommended to approach these modules?

Upvotes: 0

Views: 167

Answers (1)

Ivan

Reputation: 40618

This looks like the correct design pattern.
Ideally, you would also write your main network as an nn.Module:

class Model(nn.Sequential):
    def __init__(self, input_size, output_size):
        logvar_module = nn.Linear(input_size, output_size)
        super().__init__(logvar_module, Exp(), Diag())
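
As a quick sanity check, here is a minimal usage sketch of that Model class (it reuses the Exp and Diag modules from the question; the batch size of 8 and the zero mean below are just made-up placeholders):

import torch
from torch.distributions import MultivariateNormal

model = Model(5, 5)
x = torch.randn(8, 5)                  # hypothetical batch of inputs
mean = torch.zeros(8, 5)               # hypothetical mean, matching the batch
cov = model(x)                         # (8, 5, 5) batch of diagonal covariance matrices
dist = MultivariateNormal(mean, cov)   # second positional argument is covariance_matrix
sample = dist.sample()                 # (8, 5)

Subclassing nn.Sequential this way keeps the whole pipeline registered as a single module, so model.parameters(), model.to(device), and state_dict saving all work as expected.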

Upvotes: 1
