fmj.SO

Reputation: 19

Strange case of "RuntimeError: mat1 and mat2 shapes cannot be multiplied"

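import torch
import torch.nn as nn
import torch.nn.functional as F


class DQNetwork(nn.Module):
    # Three fully connected layers: input -> hidden -> hidden -> output

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x must have shape (batch_size, input_size) or (input_size,)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)  # one output value per action
        return x


# Inside the agent: 11 input features, 256 hidden units, 3 outputs
self.model = DQNetwork(11, 256, 3)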
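As a sanity check on the shapes this network accepts, here is a minimal sketch (assuming a recent PyTorch version) that repeats the class from the question and feeds it a single state and a batch of 5 states. In both cases the last dimension is 11, matching `input_size`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DQNetwork(nn.Module):
    """Same three-layer MLP as in the question."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)


model = DQNetwork(11, 256, 3)

# A single state: a 1-D tensor with 11 features.
single = model(torch.zeros(11))
# A batch of 5 states: one state per row, 11 features per row.
batch = model(torch.zeros(5, 11))

print(single.shape)  # torch.Size([3])
print(batch.shape)   # torch.Size([5, 3])
```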

Traceback (most recent call last):
  File "E:/Work/Programming/PyArk/main.py", line 32, in <module>
    agent.train()
  File "E:\Work\Programming\PyArk\Agent\agent.py", line 31, in train
    self.step(states, actions, rewards, next_states, dones)
  File "E:\Work\Programming\PyArk\Agent\agent.py", line 20, in step
    self._trainer.train(state, action, reward, next_state, done)
  File "E:\Work\Programming\PyArk\Agent\DQN\dqn_trainer.py", line 32, in train
    prediction = self.model(state)
  File "E:\Work\Programming\PyArk\venv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\Work\Programming\PyArk\Agent\DQN\dqn_network.py", line 19, in forward
    x = F.relu(self.linear1(x))
  File "E:\Work\Programming\PyArk\venv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\Work\Programming\PyArk\venv\lib\site-packages\torch\nn\modules\linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "E:\Work\Programming\PyArk\venv\lib\site-packages\torch\nn\functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (11x5 and 11x256)
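The last frames show that the failure happens inside `nn.Linear`: it stores its weight with shape (out_features, in_features) = (256, 11) and multiplies the input by the transposed weight, so the input's last dimension has to be 11. A minimal reproduction, assuming PyTorch; the exact wording of the error message can differ between PyTorch versions, so the sketch only prints it:

```python
import torch
import torch.nn as nn

layer = nn.Linear(11, 256)
print(layer.weight.shape)  # torch.Size([256, 11])

good = torch.randn(5, 11)  # (batch=5, features=11)
bad = torch.zeros(11, 5)   # same numbers of rows/columns, but swapped

# nn.Linear computes x @ weight.T + bias, so x.shape[-1] must equal 11.
manual = good @ layer.weight.T + layer.bias
same_result = torch.allclose(manual, layer(good), atol=1e-6)

raised = False
try:
    layer(bad)  # (11x5) @ (11x256): inner dimensions 5 and 11 do not match
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```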

I don't understand why this error is raised.

I use the same code in other projects... so what is going on?

Upvotes: 1

Views: 122

Answers (1)

Alexey Birukov

Reputation: 1660

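The input is transposed. nn.Linear(11, 256) expects the last dimension of its input to be 11 (the number of state features), but the state you pass has shape (11, 5). Pass the batch with one state per row instead:

model( torch.zeros(11,5) )  -->   model( torch.zeros(5,11) )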
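One plausible way to end up with an (11, 5) tensor is stacking 5 per-step states along the wrong dimension when sampling a batch. The question does not show how `state` is built in dqn_trainer.py, so the `states` list below is a hypothetical stand-in, and `nn.Linear(11, 3)` replaces the full DQNetwork for brevity:

```python
import torch
import torch.nn as nn

# Hypothetical sample from replay memory: 5 states with 11 features each.
states = [torch.rand(11) for _ in range(5)]

wrong = torch.stack(states, dim=1)  # shape (11, 5): one state per COLUMN
right = torch.stack(states, dim=0)  # shape (5, 11): one state per ROW

model = nn.Linear(11, 3)  # stand-in for DQNetwork(11, 256, 3)

# Cheap guard before the forward pass: features must be the last dimension.
assert right.shape[-1] == model.in_features

# Either build the batch along dim 0, or transpose a batch that is already (11, 5).
out_right = model(right)
out_fixed = model(wrong.T)

print(out_right.shape)              # torch.Size([5, 3])
print(torch.equal(right, wrong.T))  # True
```

Building the batch along dim 0 is usually cleaner than transposing afterwards, because every later layer and loss then sees the conventional (batch, features) layout.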

Upvotes: 0
