Reputation: 11
In the gemma_pytorch repo (https://github.com/google/gemma_pytorch/tree/3294a89203f6227dc828b136564d3cb23bc6d115), I find that it only feeds the previous token when generating each next token from the second step onwards (https://github.com/google/gemma_pytorch/blob/3294a89203f6227dc828b136564d3cb23bc6d115/gemma/model.py#L659):
for i in range(max_seq_len - min_prompt_len):
    next_token_ids, _ = self(
        input_token_ids=input_token_ids_tensor,
        input_positions=input_positions_tensor,
        kv_write_indices=None,
        kv_caches=kv_caches,
        mask=curr_mask_tensor,
        output_positions=output_positions_tensor,
        temperatures=temperatures_tensor,
        top_ps=top_ps_tensor,
        top_ks=top_ks_tensor,
    )
    curr_prompt_mask = prompt_mask_tensor.index_select(
        1, output_index).squeeze(dim=1)
    curr_token_ids = token_ids_tensor.index_select(
        1, output_index).squeeze(dim=1)
    output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                   next_token_ids).unsqueeze(dim=1)
    token_ids_tensor.index_copy_(1, output_index, output_token_ids)
    input_token_ids_tensor = output_token_ids
    input_positions_tensor = output_index.unsqueeze(dim=-1)
    curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
    output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
    output_index = output_index + 1
It sets input_positions_tensor to output_index, which is a single token position. Doesn't this lose all of the previous sentence context? From my point of view, this is equivalent to predicting "what is the next token when a single token sits at the given position". No previous context is provided at all.
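For reference, this is roughly what I expected full-context generation to look like without any KV cache, i.e. re-feeding the whole prefix at every step. This is only a minimal sketch with a hypothetical model call signature, not the actual Gemma API:

import torch

# Minimal sketch (hypothetical interface): re-feed the entire prefix at every
# step so the attention layers can see all previous tokens. `model` and its
# call signature are made up for illustration and are not the Gemma API.
def generate_without_cache(model, prompt_ids, max_new_tokens):
    token_ids = prompt_ids  # shape: [batch, seq_len]
    for _ in range(max_new_tokens):
        positions = torch.arange(token_ids.shape[1], device=token_ids.device)
        logits = model(token_ids, positions)  # recompute over the whole prefix
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        token_ids = torch.cat((token_ids, next_token), dim=1)
    return token_ids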
The correct implementation, as I understand it, should append the currently predicted token to the previous context, i.e. something like input_positions_tensor = torch.cat((input_positions_tensor, output_index.unsqueeze(dim=-1)), dim=-1).
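In code, the change I have in mind is roughly the following (a hypothetical sketch of my proposal, not tested against the repo), growing both the input tokens and their positions each step:

# Hypothetical sketch of my proposed change inside the generation loop:
# keep appending the new token and its position so every forward pass
# receives the full context instead of a single token.
input_token_ids_tensor = torch.cat(
    (input_token_ids_tensor, output_token_ids), dim=1)
input_positions_tensor = torch.cat(
    (input_positions_tensor, output_index.unsqueeze(dim=-1)), dim=-1)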
Am I wrong? Could anyone help me understand what's happening here? Please explain how a decoder-only causal LLM generates a full sentence, and whether the code in the gemma_pytorch repo is correct.
Upvotes: 1
Views: 23