Reputation: 2279
I am trying to use DynamicCache with an input prompt. After caching the prompt I want to generate one output at a time, append it to the previous prompt, and then generate again. I am running the following code (a toy example that reproduces the error):
import os
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache

# model and tokenizer are loaded beforehand with
# AutoModelForCausalLM.from_pretrained(...) and AutoTokenizer.from_pretrained(...)

max_length = 6
prompt_cache = DynamicCache()
INITIAL_PROMPT = "I have a dream"
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt", padding=True).to("cuda")

# This is the common prompt being cached; the forward pass has to run without grad
# so that the cache can be deep-copied later
with torch.no_grad():
    prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values.to("cuda")

input_text = INITIAL_PROMPT
# input_text = INITIAL_PROMPT + " "
responses = []
for _ in range(max_length):
    new_inputs = tokenizer(input_text, return_tensors="pt", padding=True).to("cuda")
    past_key_values = copy.deepcopy(prompt_cache)
    outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20,
                             pad_token_id=tokenizer.eos_token_id, do_sample=True, temperature=0.7)
    output_text = tokenizer.decode(outputs[0])
    print(output_text[len(input_text):])
    input_text = output_text
    print("#" * 24)
The above code gives the following error:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[61], line 20
18 new_inputs = tokenizer(input_text, return_tensors="pt", padding=True).to("cuda")
19 past_key_values = copy.deepcopy(prompt_cache)
---> 20 outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20,pad_token_id=tokenizer.eos_token_id, do_sample=True, temperature=0.7)
21 output_text = tokenizer.decode(outputs[0])
22 print(output_text[len(input_text):])
File /media/data1/haque/.conda/envs/llm/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File /media/data1/haque/.conda/envs/llm/lib/python3.12/site-packages/transformers/generation/utils.py:2252, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
2244 input_ids, model_kwargs = self._expand_inputs_for_generation(
2245 input_ids=input_ids,
2246 expand_size=generation_config.num_return_sequences,
2247 is_encoder_decoder=self.config.is_encoder_decoder,
2248 **model_kwargs,
2249 )
2251 # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2252 result = self._sample(
2253 input_ids,
2254 logits_processor=prepared_logits_processor,
2255 stopping_criteria=prepared_stopping_criteria,
2256 generation_config=generation_config,
2257 synced_gpus=synced_gpus,
2258 streamer=streamer,
2259 **model_kwargs,
2260 )
2262 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
2263 # 11. prepare beam search scorer
2264 beam_scorer = BeamSearchScorer(
2265 batch_size=batch_size,
2266 num_beams=generation_config.num_beams,
(...)
2271 max_length=generation_config.max_length,
2272 )
File /media/data1/haque/.conda/envs/llm/lib/python3.12/site-packages/transformers/generation/utils.py:3244, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
3239 is_prefill = True
3240 while self._has_unfinished_sequences(
3241 this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
3242 ):
3243 # prepare model inputs
-> 3244 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
3246 # prepare variable output controls (note: some models won't accept all output controls)
3247 model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
File /media/data1/haque/.conda/envs/llm/lib/python3.12/site-packages/transformers/generation/utils.py:388, in GenerationMixin.prepare_inputs_for_generation(self, input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs)
384 if past_key_values is not None:
385 model_inputs["past_key_values"] = past_key_values
386 if (
387 inputs_embeds is not None # Exception 1
--> 388 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
389 ):
390 input_ids = input_ids[:, -cache_position.shape[0] :]
391 elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
IndexError: index -1 is out of bounds for dimension 0 with size 0
But when I instead use input_text = INITIAL_PROMPT + " " (the commented-out line above), it works fine. What is the actual issue with cache_position, and how can I resolve it without appending an extra string (here " ") to input_text?
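For reference, here is a rough sketch of the comparison I have in mind: how many tokens the cache already covers versus how many tokens re-tokenizing the input produces, with and without the trailing " ". It assumes the same model and tokenizer objects as in the snippet above; get_seq_length() is the sequence length reported by the cache.

# Rough sketch: compare the number of tokens stored in the cache with the number
# of tokens produced by re-tokenizing input_text (with and without the trailing " ").
prompt_inputs = tokenizer(INITIAL_PROMPT, return_tensors="pt", padding=True).to("cuda")
with torch.no_grad():
    cache = model(**prompt_inputs, past_key_values=DynamicCache()).past_key_values

retokenized = tokenizer(INITIAL_PROMPT, return_tensors="pt", padding=True).to("cuda")
retokenized_ws = tokenizer(INITIAL_PROMPT + " ", return_tensors="pt", padding=True).to("cuda")

print("cache length      :", cache.get_seq_length())               # tokens already in the cache
print("input length      :", retokenized["input_ids"].shape[1])    # same prompt, re-tokenized
print("input length + ' ':", retokenized_ws["input_ids"].shape[1]) # prompt with the extra space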
Upvotes: 0
Views: 21