I'm loading a PaliGemma2 model (google/paligemma2-3b-pt-224) and trying to fine-tune it using Trainer/Seq2SeqTrainer. As soon as I add evaluation, training fails. After some digging, I found that the failure only happens when the model is in eval mode.
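For context, the training setup is along these lines (a simplified sketch; train_dataset and the exact TrainingArguments are placeholders, while collate_fn and valid_dataset are the same helpers used below):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="paligemma2-ft",
    per_device_train_batch_size=1,
    eval_strategy="steps",   # adding evaluation is what triggers the failure
    eval_steps=50,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=collate_fn,
)
trainer.train()  # the first evaluation pass raises the RuntimeError below

The snippet below reproduces the same failure without running a full training loop: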
batch = [valid_dataset[i] for i in range(8)]
inputs = collate_fn(batch)
#generate_ids = model.generate(**inputs, max_length=286+30)

# Loss computation succeeds while the model is in train mode ...
trainer.model.train()
trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
print("works")

# ... but the same call fails once the model is switched to eval mode
trainer.model.train(False)
trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
print("fails.")
I've worked around it by monkey-patching compute_loss_context_manager as follows:
orig_context_manager = trainer.compute_loss_context_manager

class TempTrainContext(object):
    """Wraps the original context manager and forces train mode for the forward pass."""
    def __init__(self, trainer):
        self.trainer = trainer
        self.orig_context_manager = trainer.compute_loss_context_manager

    def __enter__(self):
        # Enter the original context manager first, then switch the model to train mode
        self.orig_context_inst = self.orig_context_manager()
        self.orig_context_inst.__enter__()
        self.training_enter = self.trainer.model.training
        self.trainer.model.train()

    def __exit__(self, type, value, traceback):
        # Restore whatever training/eval state the model was in before
        self.trainer.model.train(self.training_enter)
        self.orig_context_inst.__exit__(type, value, traceback)

    def __call__(self):
        # Trainer calls compute_loss_context_manager(), so this must be callable
        return self

trainer.compute_loss_context_manager = TempTrainContext(trainer)
(Bonus question: Is this safe to do, or will I train on the test set?)
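For reference, this is roughly how the Trainer uses that context manager internally (a simplified sketch, not the exact transformers source); it also explains why the replacement object has to be callable and return itself:

# Simplified sketch of what Trainer.training_step / prediction_step do:
with trainer.compute_loss_context_manager():   # __call__() returns the patched object, __enter__ forces train mode
    loss = trainer.compute_loss(model, inputs)
# __exit__ restores the model's previous training/eval state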
My versions are:
Python Version: 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0]
Torch Version: 2.5.1+cu124
CUDA Available: True
CUDA Device Count: 2
GPU Name: NVIDIA GeForce RTX 3090
Transformers Version: 4.48.1
Tokenizers Version: 0.21.0
Accelerate Version: 1.3.0
Error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[13], line 8
6 print("works")
7 trainer.model.train(False)
----> 8 trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
9 print("fails.")
12 orig_context_manager = trainer.compute_loss_context_manager
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3731, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3729 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3730 inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
3732 # Save past state if it exists
3733 # TODO: this needs to be fixed and made cleaner later.
3734 if self.args.past_index >= 0:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
525 labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
527 causal_mask = self._update_causal_mask(
528 attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
529 )
--> 530 outputs = self.language_model(
531 attention_mask=causal_mask,
532 position_ids=position_ids,
533 past_key_values=past_key_values,
534 inputs_embeds=inputs_embeds,
535 use_cache=use_cache,
536 output_attentions=output_attentions,
537 output_hidden_states=output_hidden_states,
538 return_dict=return_dict,
539 cache_position=cache_position,
540 num_logits_to_keep=num_logits_to_keep,
541 )
543 logits = outputs.logits
544 loss = None
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:842, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
840 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
841 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 842 outputs = self.model(
843 input_ids=input_ids,
844 attention_mask=attention_mask,
845 position_ids=position_ids,
846 past_key_values=past_key_values,
847 inputs_embeds=inputs_embeds,
848 use_cache=use_cache,
849 output_attentions=output_attentions,
850 output_hidden_states=output_hidden_states,
851 return_dict=return_dict,
852 cache_position=cache_position,
853 )
855 hidden_states = outputs[0]
856 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:629, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
617 layer_outputs = self._gradient_checkpointing_func(
618 decoder_layer.__call__,
619 hidden_states,
(...)
626 cache_position,
627 )
628 else:
--> 629 layer_outputs = decoder_layer(
630 hidden_states,
631 position_embeddings=position_embeddings,
632 attention_mask=causal_mask,
633 position_ids=position_ids,
634 past_key_value=past_key_values,
635 output_attentions=output_attentions,
636 use_cache=use_cache,
637 cache_position=cache_position,
638 **flash_attn_kwargs,
639 )
641 hidden_states = layer_outputs[0]
643 if output_attentions:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
296 hidden_states = self.input_layernorm(hidden_states)
298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
300 hidden_states=hidden_states,
301 position_embeddings=position_embeddings,
302 attention_mask=attention_mask,
303 position_ids=position_ids,
304 past_key_value=past_key_value,
305 output_attentions=output_attentions,
306 use_cache=use_cache,
307 cache_position=cache_position,
308 )
309 hidden_states = self.post_attention_layernorm(hidden_states)
310 hidden_states = residual + hidden_states
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:224, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
221 if past_key_value is not None:
222 # sin and cos are specific to RoPE models; cache_position needed for the static cache
223 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 224 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
226 attention_interface: Callable = eager_attention_forward
227 if self.config._attn_implementation != "eager":
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1717, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
1714 else:
1715 update_fn = self._static_update
-> 1717 return update_fn(
1718 cache_position,
1719 layer_idx,
1720 key_states,
1721 value_states,
1722 k_out,
1723 v_out,
1724 k_out.shape[2],
1725 )
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1694, in HybridCache._static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
1693 def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-> 1694 k_out[:, :, cache_position] = key_states
1695 v_out[:, :, cache_position] = value_states
1697 self.key_cache[layer_idx] = k_out
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
The evaluator's error (bottom half of this gist): https://gist.github.com/BlGene/607c7bee450e03835aa2bf0d2fd2959a
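For completeness, the accelerate hooks in the traceback come from the model being sharded across both GPUs via a device_map; assuming it was loaded that way, the layer-to-device split can be inspected like this:

# hf_device_map is set by accelerate when the model is loaded with a device_map
print(model.hf_device_map)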