Memduh

Reputation: 101

PyTorch crashes on input in eval mode

My model trains perfectly fine, but when I switch it to evaluation mode it does not like the data types of the input samples:

Traceback (most recent call last):
  File "model.py", line 558, in <module>
    main_function(train_sequicity=args.train)
  File "model.py", line 542, in main_function
    out = model(user, bspan, response_, degree)
  File "/home/memduh/git/project/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "model.py", line 336, in forward
    self.params['bspan_size'])
  File "model.py", line 283, in _greedy_decode_output
    out = decoder(input_, encoder_output)
  File "/home/memduh/git/project/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "model.py", line 142, in forward
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence lenght axis
RuntimeError: Expected object of scalar type Long but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'

The error seems to occur in a part of the code where concatenation happens. The architecture is similar to the PyTorch Transformer, just modified to have two decoders:

def forward(self, tgt, memory):
    """ Call decoder
    the decoder should be called repeatedly

    Args:
        tgt: input to transformer_decoder, shape: (seq, batch)
        memory: output from the encoder

    Returns:
        output from linear layer, (vocab size), pre softmax

    """
    go_tokens = torch.zeros((1, tgt.size(1)), dtype=torch.int64) + 3  # GO_2 token has index 3

    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence length axis

    mask = tgt.eq(0).transpose(0, 1)  # 0 corresponds to <pad>
    tgt = self.embedding(tgt) * self.ninp
    tgt = self.pos_encoder(tgt)
    tgt_mask = self._generate_square_subsequent_mask(tgt.size(0))
    output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=mask)
    output = self.linear(output)
    return output

The concatenation in the middle of the code block is where the problem happens. The odd thing is that it works perfectly fine in train mode, with the loss going down; the issue only comes up in eval mode. What could the problem be?

Upvotes: 0

Views: 551

Answers (1)

Berriel

Reputation: 13601

The error seems clear: tgt is Float, but it was expected to be Long. Why?

In your code, you define that go_tokens is torch.int64 (i.e., Long):

def forward(self, tgt, memory):
    go_tokens = torch.zeros((1, tgt.size(1)), dtype=torch.int64) + 3  # GO_2 token has index 3
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence length axis
    # [...]

You can avoid that error by saying that go_tokens should have the same data type as tgt:

def forward(self, tgt, memory):
    go_tokens = torch.zeros((1, tgt.size(1)), dtype=tgt.dtype) + 3  # GO_2 token has index 3
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence length axis
    # [...]

Now, if the rest of the code relies on tgt being torch.int64, then you should identify why tgt is torch.int64 at training time but torch.float32 at test time; otherwise another error will be thrown.
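For example, a minimal sketch of such a check, assuming tgt is meant to carry Long token indices (the simulated input and the cast below are purely illustrative and not part of the original model; the real fix is most likely upstream in the eval data pipeline):

import torch

# Simulate token ids that somehow arrive as Float in eval mode
tgt = torch.randint(0, 10, (5, 2)).float()

# Illustrative guard: cast back to Long if the input arrives as Float
if tgt.dtype != torch.long:
    print(f"tgt arrived as {tgt.dtype}, casting to torch.long")
    tgt = tgt.long()

go_tokens = torch.full((1, tgt.size(1)), 3, dtype=tgt.dtype)  # GO_2 token has index 3
tgt = torch.cat([go_tokens, tgt], dim=0)  # both tensors now share the same dtype
print(tgt.dtype)  # torch.int64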

Upvotes: 1
