I use the following model:
```python
model = DeepGraphInfomax(128, pos_summary_t).to(device)
```
which is defined as:
```python
import torch
from torch.nn import Parameter

class DeepGraphInfomax(torch.nn.Module):
    def __init__(self, hidden_channels, pos_summary):#, encoder):#, summary, corruption):
        super().__init__()
        self.hidden_channels = hidden_channels
        #self.encoder = GCNEncoder() # needs to be defined either here or give it to here (my comment)
        self.pos_summary = pos_summary
        #self.corruption = corruption
        self.weight = Parameter(torch.Tensor(hidden_channels, hidden_channels))
        #self.reset_parameters()

    def forward(self, pos_summary, *args, **kwargs):
        #pos_z = self.encoder(*args, **kwargs)
        #cor = self.corruption(*args, **kwargs)
        #cor = cor if isinstance(cor, tuple) else (cor, )
        #neg_z = self.encoder(*cor)
        # note: __init__ only sets self.pos_summary, so self.summary is not defined here
        summary = self.summary(pos_summary, *args, **kwargs)
        return summary# pos_z#, neg_z, summary
```
But running `model()` gives me the error:
```
TypeError: forward() missing 1 required positional argument: 'pos_summary'
```
`model` is an object, since you instantiated `DeepGraphInfomax`. Calling `model()` invokes its `.__call__` method (inherited from `torch.nn.Module`), and `.__call__` in turn calls `forward`, passing on whatever arguments you gave to `model()`.
Have a look here.
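To make the argument flow concrete, here is a heavily simplified sketch in plain Python. `MiniModule` and `NeedsInput` are hypothetical names, and this is not PyTorch's actual implementation (the real `nn.Module.__call__` also runs hooks); it only shows that whatever you pass to `model(...)` is what `forward` receives, so calling `model()` with no arguments leaves `pos_summary` unfilled.

```python
# Hypothetical, simplified stand-in for nn.Module's call dispatch.
# Only argument passing is modelled; hooks and other bookkeeping are omitted.
class MiniModule:
    def __call__(self, *args, **kwargs):
        # Whatever is passed to model(...) is handed straight to forward(...)
        return self.forward(*args, **kwargs)


class NeedsInput(MiniModule):
    def forward(self, pos_summary, *args, **kwargs):
        return pos_summary * 2


model = NeedsInput()
print(model(21))  # forward receives pos_summary=21, so this prints 42

try:
    model()  # nothing is forwarded, so pos_summary is missing
except TypeError as err:
    print(err)  # message names the missing 'pos_summary' argument
```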
The `TypeError` means that `forward` requires an input (`pos_summary`) that you did not provide, so you have to pass it when calling the model, e.g. `model(data)`.
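A minimal sketch of the fix, assuming a hypothetical `SummaryModel` that has the same `forward(self, pos_summary, *args, **kwargs)` signature as the question's model. The mean-pooling summary and the 10 x 128 input shape are made up for illustration; the point is only that the tensor has to go through the call, e.g. `model(pos_summary_t)`, rather than `model()`.

```python
import torch
import torch.nn as nn


# Hypothetical stand-in with the same forward signature as the question's
# DeepGraphInfomax; the summary computed here (a mean) is an assumption.
class SummaryModel(nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.weight = nn.Parameter(torch.eye(hidden_channels))

    def forward(self, pos_summary, *args, **kwargs):
        # Project the input and average over the first (node) dimension
        return (pos_summary @ self.weight).mean(dim=0)


model = SummaryModel(128)
pos_summary_t = torch.randn(10, 128)  # assumed shape: 10 nodes x 128 features

# model() would raise the TypeError from the question: nothing reaches forward.
summary = model(pos_summary_t)  # pos_summary_t is passed on as pos_summary
print(summary.shape)  # torch.Size([128])
```

Passing the input at call time, rather than only to `__init__`, is what `nn.Module` expects: the constructor sets up parameters, and each call supplies the data.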
Here is an example:
```python
import torch
import torch.nn as nn

class example(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(5, 2), nn.ReLU(), nn.Linear(2, 1), nn.Sigmoid())

    def forward(self, x):
        return self.mlp(x)

# instantiate object
test = example()
input_data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

# calling the module and calling forward() directly give the same result here
print(test(input_data))
print(test.forward(input_data))

# outputs for me (exact values depend on the random weight initialization)
# tensor([0.5387], grad_fn=<SigmoidBackward>)
# tensor([0.5387], grad_fn=<SigmoidBackward>)
```
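The two calls return the same tensor in this example, but they are not fully interchangeable: `nn.Module.__call__` also runs any registered hooks, which a direct `forward()` call skips. A small sketch, assuming a plain `nn.Linear` layer and a hook that only records that it ran:

```python
import torch
import torch.nn as nn

calls = []

layer = nn.Linear(5, 1)
# A forward hook runs after forward() when the module is called as layer(x)
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.ones(5)
out_call = layer(x)             # goes through __call__, so the hook fires
out_forward = layer.forward(x)  # bypasses __call__, so the hook does not fire

print(torch.equal(out_call, out_forward))  # True: same weights, same input
print(calls)  # ['hook']
```

This is why calling the module, `model(data)`, is generally preferred over calling `model.forward(data)` directly.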