Reputation: 101
My model trains perfectly fine, but when I switch it to evaluation mode it does not like the data types of the input samples:
Traceback (most recent call last):
  File "model.py", line 558, in <module>
    main_function(train_sequicity=args.train)
  File "model.py", line 542, in main_function
    out = model(user, bspan, response_, degree)
  File "/home/memduh/git/project/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "model.py", line 336, in forward
    self.params['bspan_size'])
  File "model.py", line 283, in _greedy_decode_output
    out = decoder(input_, encoder_output)
  File "/home/memduh/git/project/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "model.py", line 142, in forward
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence length axis
RuntimeError: Expected object of scalar type Long but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'
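The dtype mismatch is easy to reproduce in isolation. A minimal standalone snippet (assuming the PyTorch version from the traceback; newer releases may type-promote in torch.cat instead of raising):

import torch

long_t = torch.zeros((1, 2), dtype=torch.int64)     # Long, like go_tokens
float_t = torch.zeros((4, 2), dtype=torch.float32)  # Float
torch.cat([long_t, float_t], dim=0)  # RuntimeError: Expected object of scalar type Long but got scalar type Float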
This seems to occur in a part of the code where concatenation happens. It is in an architecture similar to the PyTorch Transformer, just modified to have two decoders:
def forward(self, tgt, memory):
    """ Call decoder
    the decoder should be called repeatedly
    Args:
        tgt: input to transformer_decoder, shape: (seq, batch)
        memory: output from the encoder
    Returns:
        output from linear layer, (vocab size), pre softmax
    """
    go_tokens = torch.zeros((1, tgt.size(1)), dtype=torch.int64) + 3  # GO_2 token has index 3
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence length axis

    mask = tgt.eq(0).transpose(0, 1)  # 0 corresponds to <pad>
    tgt = self.embedding(tgt) * self.ninp
    tgt = self.pos_encoder(tgt)
    tgt_mask = self._generate_square_subsequent_mask(tgt.size(0))
    output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=mask)
    output = self.linear(output)
    return output
The concatenation in the middle of the code block is where the problem happens. The odd thing is that it works perfectly fine in train mode, with the loss going down. The issue only comes up in eval mode. What could the problem be?
Upvotes: 0
Views: 551
Reputation: 13601
The error seems to be clear: tgt is Float, but it was expected to be Long. Why?
In your code, you define go_tokens as torch.int64 (i.e., Long):
def forward(self, tgt, memory):
    go_tokens = torch.zeros((1, tgt.size(1)), dtype=torch.int64) + 3  # GO_2 token has index 3
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence length axis
    # [...]
You can avoid that error by giving go_tokens the same data type as tgt:
def forward(self, tgt, memory):
    go_tokens = torch.zeros((1, tgt.size(1)), dtype=tgt.dtype) + 3  # GO_2 token has index 3
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence length axis
    # [...]
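As a side note, torch.full expresses the same intent a bit more directly (equivalent to the zeros-plus-3 line above):

go_tokens = torch.full((1, tgt.size(1)), 3, dtype=tgt.dtype)  # GO_2 token has index 3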
Now, if the rest of the code relies on tgt being torch.int64, then you should identify why tgt is torch.int64 at training time and torch.float32 at test time, otherwise another error will be thrown.
Upvotes: 1