Dr. S

Reputation: 41

TypeError: forward() missing 1 required positional argument in a method

I use the following model:

model = DeepGraphInfomax(128, pos_summary_t).to(device)

which looks like:

class DeepGraphInfomax(torch.nn.Module):

    def __init__(self, hidden_channels, pos_summary):#, encoder):#, summary, corruption):
        super().__init__()
        self.hidden_channels = hidden_channels
        #self.encoder = GCNEncoder() # needs to be defined either here or give it to here (my comment)
        self.pos_summary = pos_summary
        #self.corruption = corruption

        self.weight = Parameter(torch.Tensor(hidden_channels, hidden_channels))

        #self.reset_parameters()

    def forward(self, pos_summary, *args, **kwargs):
        #pos_z = self.encoder(*args, **kwargs)
        #cor = self.corruption(*args, **kwargs)
        #cor = cor if isinstance(cor, tuple) else (cor, )
        #neg_z = self.encoder(*cor)
        summary = self.summary(pos_summary, *args, **kwargs)
        return summary# pos_z#, neg_z, summary

but running model() gives me the error:

TypeError: forward() missing 1 required positional argument: 'pos_summary'

Upvotes: 1

Views: 22141

Answers (1)

Ethan Si

Reputation: 43

model is an object, since you instantiated DeepGraphInfomax. Calling model() invokes the object's __call__ method, and __call__ in turn calls forward with whatever arguments you passed. The TypeError means forward was called without the input it requires, so pass that input when you call the model, i.e. model(data).
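To make this concrete for the case in the question, here is a minimal sketch that reproduces the error and the fix. TinySummary is a hypothetical stand-in for DeepGraphInfomax (it is not the asker's model); the only thing it shares with it is a forward that requires a pos_summary argument.

```python
import torch
import torch.nn as nn


class TinySummary(nn.Module):
    """Hypothetical stand-in: forward requires one input, like the question's model."""

    def __init__(self, hidden_channels):
        super().__init__()
        self.linear = nn.Linear(hidden_channels, hidden_channels)

    def forward(self, pos_summary):
        return self.linear(pos_summary)


model = TinySummary(4)

# Calling the module with no arguments passes nothing on to forward(),
# so Python raises the TypeError from the question.
try:
    model()
    error_message = None
except TypeError as err:
    error_message = str(err)

# Passing the input at call time satisfies forward's signature.
output = model(torch.zeros(4))

print(error_message)
print(output.shape)
```

Note that passing pos_summary to the constructor does not help here: arguments given to __init__ are not forwarded to forward, so the input still has to be supplied on each call.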

Here is an example:


import torch
import torch.nn as nn

class example(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(5,2), nn.ReLU(), nn.Linear(2,1), nn.Sigmoid())

    def forward(self, x):
        return self.mlp(x)

# instantiate object
test = example()

input_data = torch.tensor([1.0,2.0,3.0,4.0,5.0])

# calling the object and calling forward() directly give the same result here
print(test(input_data)) 
print(test.forward(input_data))

# outputs for me
#(tensor([0.5387], grad_fn=<SigmoidBackward>),
# tensor([0.5387], grad_fn=<SigmoidBackward>))
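One caveat worth adding: the two calls return the same value, but they are not strictly interchangeable. Calling the module goes through __call__, which also runs any registered hooks, whereas calling forward directly skips them. The sketch below illustrates this with register_forward_hook; the names record_call and calls are made up for the illustration.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 1)
calls = []  # filled in by the hook each time it fires


def record_call(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(tuple(output.shape))


handle = layer.register_forward_hook(record_call)
x = torch.ones(3)

via_call = layer(x)             # goes through __call__, so the hook fires
hooks_after_call = len(calls)

via_forward = layer.forward(x)  # bypasses __call__, so the hook does not fire
hooks_after_forward = len(calls)

handle.remove()

print(torch.equal(via_call, via_forward))
print(hooks_after_call, hooks_after_forward)
```

For this reason it is usually preferable to write model(data) rather than model.forward(data).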

Upvotes: 2
