Driss AL

Train a model that is instantiated inside another model (PyTorch)

I have two neural network classes: one is a GNN and the other is a simple linear network. The linear network is instantiated inside the GNN. How can I train both at the same time? Here is an example:

class linear_NN(nn.Module):
  
  def __init__(self, input_dim, out_dim...):
    super().__init__()

  def forward(self, x, dim = 0):
    '''Forward pass'''
    return x

The main (outer) class:

class GNN(nn.Module):
  
  def __init__(self, input_dim, n_hidden, out_dim...):
    super().__init__()

  def forward(self, h, dim = 0):
    '''Forward pass'''
    model=linear_NN(input, out..)
    model(h, dim)
    return h
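
A minimal sketch of why this setup does not train, assuming (hypothetically) that linear_NN wraps a single nn.Linear layer and dropping the dim argument. Because the submodule is created inside forward, the outer model registers no parameters, and every call uses freshly initialized weights:

```python
import torch
import torch.nn as nn

class LinearNN(nn.Module):
    # Hypothetical stand-in for linear_NN: a single nn.Linear layer.
    def __init__(self, input_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, out_dim)

    def forward(self, x):
        return self.fc(x)

class GNNInForward(nn.Module):
    # Mirrors the question's pattern: the submodule is built inside forward().
    def __init__(self, input_dim, out_dim):
        super().__init__()
        self.input_dim = input_dim  # plain ints are not registered as parameters
        self.out_dim = out_dim

    def forward(self, h):
        model = LinearNN(self.input_dim, self.out_dim)  # new random weights on every call
        return model(h)

net = GNNInForward(4, 2)
x = torch.randn(3, 4)
print(len(list(net.parameters())))     # 0: nothing is registered, so nothing can be optimized
print(torch.allclose(net(x), net(x)))  # False: each call uses newly initialized weights
```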

Answers (1)

Berriel

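You must declare it in __init__(...) and store it as an attribute. In your version, a new linear_NN is created on every forward call, so its weights are re-initialized each time and are never registered with GNN; an optimizer built from GNN(...).parameters() never sees them: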

class GNN(nn.Module):
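  def __init__(self, input_dim, n_hidden, out_dim, ...):
    super().__init__()
    # Assigning the submodule as an attribute registers it with GNN
    self.linear = linear_NN(input, out..)

  def forward(self, h, dim = 0):
    '''Forward pass'''
    h = self.linear(h, dim)  # keep the submodule's output instead of discarding it
    return h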
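
A minimal runnable sketch of this fix. It assumes (hypothetically) that linear_NN is a single nn.Linear layer and uses a plain dense layer as a stand-in for the real GNN message passing, which the question does not show; the dim argument is dropped for simplicity. named_parameters() lists the nested module's weights under the linear. prefix:

```python
import torch
import torch.nn as nn

class LinearNN(nn.Module):
    # Hypothetical stand-in for linear_NN: one nn.Linear layer.
    def __init__(self, input_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, out_dim)

    def forward(self, x):
        return self.fc(x)

class GNN(nn.Module):
    # Hypothetical simplified GNN: a dense "hidden" layer stands in for real
    # message passing, followed by the registered LinearNN submodule.
    def __init__(self, input_dim, n_hidden, out_dim):
        super().__init__()
        self.hidden = nn.Linear(input_dim, n_hidden)
        self.linear = LinearNN(n_hidden, out_dim)  # registered because it is an attribute

    def forward(self, h):
        h = torch.relu(self.hidden(h))
        h = self.linear(h)  # keep the submodule's output
        return h

model = GNN(input_dim=4, n_hidden=8, out_dim=2)
for name, p in model.named_parameters():
    print(name, tuple(p.shape))
# hidden.weight (8, 4)
# hidden.bias (8,)
# linear.fc.weight (2, 8)
# linear.fc.bias (2,)
```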

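Declared this way, self.linear is registered as a submodule of your main GNN model, so GNN(...).parameters() also returns the linear model's parameters. Passing those parameters to a single optimizer therefore trains both networks at the same time.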
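
A sketch of training both networks at once, using the same hypothetical stand-ins as above and random toy data. One optimizer built from model.parameters() updates both the outer layer and the nested LinearNN on every step:

```python
import torch
import torch.nn as nn

class LinearNN(nn.Module):
    # Hypothetical stand-in for linear_NN.
    def __init__(self, input_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, out_dim)

    def forward(self, x):
        return self.fc(x)

class GNN(nn.Module):
    def __init__(self, input_dim, n_hidden, out_dim):
        super().__init__()
        self.hidden = nn.Linear(input_dim, n_hidden)  # hypothetical stand-in for GNN layers
        self.linear = LinearNN(n_hidden, out_dim)     # nested, registered submodule

    def forward(self, h):
        return self.linear(torch.relu(self.hidden(h)))

torch.manual_seed(0)
model = GNN(4, 8, 2)
# One optimizer over model.parameters() covers the outer and the nested network.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(16, 4), torch.randn(16, 2)  # random toy data
before_hidden = model.hidden.weight.detach().clone()
before_linear = model.linear.fc.weight.detach().clone()

for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()   # gradients flow into both submodules
    optimizer.step()  # both submodules are updated

print(not torch.equal(before_hidden, model.hidden.weight))     # True: outer layer trained
print(not torch.equal(before_linear, model.linear.fc.weight))  # True: nested layer trained
```

The same single-optimizer loop works for any depth of nesting, as long as every submodule is assigned as an attribute (or placed in a container such as nn.ModuleList) in __init__.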
