BlGene

Reputation: 17

Transformers PaliGemma: Trainer evaluate and compute_loss fail with a tensor device-mismatch error in eval mode

I'm loading the PaliGemma 2 model google/paligemma2-3b-pt-224 and trying to fine-tune it with Trainer/Seq2SeqTrainer. As soon as I add evaluation, it fails. After some digging, I found that the failure only happens when the model is in evaluation mode; the same call succeeds in training mode:

# Minimal repro: run Trainer.compute_loss on the same batch twice, first with
# the model in training mode and then with training mode switched off.
batch = [valid_dataset[i] for i in range(8)]  # first 8 validation examples
inputs = collate_fn(batch)
# Generation call, left disabled for this repro:
#generate_ids = model.generate(**inputs, max_length=286+30)
# NOTE(review): `model` is passed to compute_loss while the mode is toggled on
# `trainer.model`; presumably they are the same object -- confirm.
# NOTE(review): num_items_in_batch=416 is hard-coded; presumably the number of
# label tokens in this batch -- TODO confirm.
trainer.model.train()
trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
print("works")
# train(False) turns training mode off (eval mode); the compute_loss call below
# is the one that raises the cuda:0 / cuda:1 device-mismatch RuntimeError.
trainer.model.train(False)
trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
print("fails.")

I've worked around it by monkey-patching compute_loss_context_manager so that the model is temporarily switched to training mode, as follows:

# Handle on the trainer's original, unpatched compute_loss_context_manager.
# It is not read anywhere else in this snippet (the wrapper class captures its
# own copy); presumably kept so the patch can be undone later -- confirm intent.
orig_context_manager = trainer.compute_loss_context_manager
class TempTrainContext:
    """Wrap a trainer's ``compute_loss_context_manager`` to force training mode.

    Workaround for the eval-mode device-mismatch error: while the context is
    active, ``trainer.model`` is switched to training mode, and its previous
    mode is restored on exit. The trainer's original context manager is
    entered first and exited last, so whatever it sets up still wraps the
    loss computation.

    An instance is installed in place of the original attribute and used the
    same way: ``with trainer.compute_loss_context_manager(): ...``.
    ``__call__`` returns the instance itself. Per-entry state is kept on a
    stack, so nested or repeated ``with`` blocks each restore correctly.

    NOTE(review): forcing ``train()`` presumably also switches mode-dependent
    layers (e.g. dropout) to their training behaviour. Losses computed under
    this context may therefore differ from true eval-mode losses -- confirm
    that is acceptable for evaluation.
    """

    def __init__(self, trainer):
        """Store *trainer* and capture its current compute_loss_context_manager.

        The original factory is captured once, at construction time, so the
        instance must be created before the trainer attribute is replaced.
        """
        self.trainer = trainer
        self.orig_context_manager = trainer.compute_loss_context_manager
        # One pending-cleanup ExitStack per active `with` block, innermost last.
        self._active = []

    def __enter__(self):
        """Enter the original context, then put the model in training mode."""
        from contextlib import ExitStack

        model = self.trainer.model
        with ExitStack() as stack:
            stack.enter_context(self.orig_context_manager())
            # Register the restore *before* switching modes. If train() raises,
            # this `with` unwinds both the mode and the original context here;
            # __exit__ is never called when __enter__ raises.
            stack.callback(model.train, model.training)
            model.train()
            # Success: hand the pending cleanups over to __exit__.
            self._active.append(stack.pop_all())
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Restore the previous training mode, then exit the original context.

        Cleanups run in reverse order of entry. The combined result is
        returned, so an original context that suppresses an exception still
        does so.
        """
        return self._active.pop().__exit__(exc_type, exc_value, tb)

    def __call__(self):
        """Mimic the replaced factory by returning the context manager itself."""
        return self

# Install the wrapper as an instance attribute on `trainer`. Subsequent calls to
# trainer.compute_loss_context_manager() go through TempTrainContext.__call__.
# The right-hand side is evaluated first, so the wrapper captures the original
# (unpatched) attribute before it is replaced.
trainer.compute_loss_context_manager = TempTrainContext(trainer)

(Bonus question: is this safe to do, or will I end up training on the test set?)

My versions are:

Python Version: 3.12.7 | packaged by conda-forge | (main, Oct  4 2024, 16:05:46) [GCC 13.3.0]
Torch Version: 2.5.1+cu124
CUDA Available: True
CUDA Device Count: 2
GPU Name: NVIDIA GeForce RTX 3090
Transformers Version: 4.48.1
Tokenizers Version: 0.21.0
Accelerate Version: 1.3.0

Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 8
      6 print("works")
      7 trainer.model.train(False)
----> 8 trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
      9 print("fails.")
     12 orig_context_manager = trainer.compute_loss_context_manager

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3731, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3729         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3730     inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
   3732 # Save past state if it exists
   3733 # TODO: this needs to be fixed and made cleaner later.
   3734 if self.args.past_index >= 0:

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
    525     labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
    527 causal_mask = self._update_causal_mask(
    528     attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
    529 )
--> 530 outputs = self.language_model(
    531     attention_mask=causal_mask,
    532     position_ids=position_ids,
    533     past_key_values=past_key_values,
    534     inputs_embeds=inputs_embeds,
    535     use_cache=use_cache,
    536     output_attentions=output_attentions,
    537     output_hidden_states=output_hidden_states,
    538     return_dict=return_dict,
    539     cache_position=cache_position,
    540     num_logits_to_keep=num_logits_to_keep,
    541 )
    543 logits = outputs.logits
    544 loss = None

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:842, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
    840 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    841 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 842 outputs = self.model(
    843     input_ids=input_ids,
    844     attention_mask=attention_mask,
    845     position_ids=position_ids,
    846     past_key_values=past_key_values,
    847     inputs_embeds=inputs_embeds,
    848     use_cache=use_cache,
    849     output_attentions=output_attentions,
    850     output_hidden_states=output_hidden_states,
    851     return_dict=return_dict,
    852     cache_position=cache_position,
    853 )
    855 hidden_states = outputs[0]
    856 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:629, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
    617     layer_outputs = self._gradient_checkpointing_func(
    618         decoder_layer.__call__,
    619         hidden_states,
   (...)
    626         cache_position,
    627     )
    628 else:
--> 629     layer_outputs = decoder_layer(
    630         hidden_states,
    631         position_embeddings=position_embeddings,
    632         attention_mask=causal_mask,
    633         position_ids=position_ids,
    634         past_key_value=past_key_values,
    635         output_attentions=output_attentions,
    636         use_cache=use_cache,
    637         cache_position=cache_position,
    638         **flash_attn_kwargs,
    639     )
    641 hidden_states = layer_outputs[0]
    643 if output_attentions:

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    296 hidden_states = self.input_layernorm(hidden_states)
    298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
    300     hidden_states=hidden_states,
    301     position_embeddings=position_embeddings,
    302     attention_mask=attention_mask,
    303     position_ids=position_ids,
    304     past_key_value=past_key_value,
    305     output_attentions=output_attentions,
    306     use_cache=use_cache,
    307     cache_position=cache_position,
    308 )
    309 hidden_states = self.post_attention_layernorm(hidden_states)
    310 hidden_states = residual + hidden_states

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:224, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
    221 if past_key_value is not None:
    222     # sin and cos are specific to RoPE models; cache_position needed for the static cache
    223     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 224     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    226 attention_interface: Callable = eager_attention_forward
    227 if self.config._attn_implementation != "eager":

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1717, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
   1714 else:
   1715     update_fn = self._static_update
-> 1717 return update_fn(
   1718     cache_position,
   1719     layer_idx,
   1720     key_states,
   1721     value_states,
   1722     k_out,
   1723     v_out,
   1724     k_out.shape[2],
   1725 )

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1694, in HybridCache._static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
   1693 def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-> 1694     k_out[:, :, cache_position] = key_states
   1695     v_out[:, :, cache_position] = value_states
   1697     self.key_cache[layer_idx] = k_out

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

Traceback from the evaluator (see the bottom half of the file): https://gist.github.com/BlGene/607c7bee450e03835aa2bf0d2fd2959a

Upvotes: 0

Views: 8

Answers (0)

Related Questions