Reputation: 682
I do not understand where this error comes from, the number of arguments to the model seems correct, below is my model:
class MancalaModel(nn.Module):
def __init__(self, n_inputs=16, n_outputs=16):
super().__init__()
n_neurons = 256
def create_block(n_in, n_out):
block = nn.ModuleList()
block.append(nn.Linear(n_in, n_out))
block.append(nn.ReLU())
return block
self.blocks = nn.ModuleList()
self.blocks.append(create_block(n_inputs, n_neurons))
for _ in range(6):
self.blocks.append(create_block(n_neurons, n_neurons))
self.actor_block = nn.ModuleList()
self.critic_block = nn.ModuleList()
for _ in range(2):
self.actor_block.append(create_block(n_neurons, n_neurons))
self.critic_block.append(create_block(n_neurons, n_neurons))
self.actor_block.append(create_block(n_neurons, n_outputs))
self.critic_block.append(create_block(n_neurons, 1))
self.apply(init_weights)
def forward(self, x):
x = self.blocks(x)
actor = F.softmax(self.actor_block(x))
critics = self.critic_block(x)
return actor, critics
Then I create an instance and make a forward pass with random number
model = MancalaModel()
x = model(torch.rand(1, 16))
Then I got the TypeError saying the number of arguments is not correct:
2 model = MancalaModel()
----> 3 x = model(torch.rand(1, 16))
4 # summary(model, (16,), device='cpu')
5
d:\environments\python\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
D:\UOM\Year3\AI & Games\KalahPlayer\agents\model_agent.py in forward(self, x)
54
55 def forward(self, x):
---> 56 x = self.blocks(x)
57 actor = F.softmax(self.actor_block(x))
58 critics = self.critic_block(x)
d:\environments\python\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
TypeError: forward() takes 1 positional argument but 2 were given
Any help is appreciated, thanks!
Upvotes: 2
Views: 2706
Reputation: 114936
TL;DR
You are trying to forward
through nn.ModuleList
- this is not defined.
You need to convert self.blocks
to nn.Sequential
:
def create_block(n_in, n_out):
# do not work with ModuleList here either.
block = nn.Sequential(
nn.Linear(n_in, n_out),
nn.ReLU()
)
return block
blocks = [] # simple list - not a member of self, for temporal use only.
blocks.append(create_block(n_inputs, n_neurons))
for _ in range(6):
blocks.append(create_block(n_neurons, n_neurons))
self.blocks = nn.Sequential(*blocks) # convert the simple list to nn.Sequential
I was expecting you to get NotImplementedError
, and not this TypeError
, because your self.blocks
is of type nn.ModuleList
and its forward
method throws NotImplementedError
. I just made a pull request to fix this confusing issue.
Update (April 22nd, 2021): the PR was merged. In future versions you should see NotImplementedError
when calling nn.ModuleList
or nn.ModuleDict
.
Upvotes: 4