Ted Wang

Reputation: 11

Is gemma_pytorch's next-token generation code setting input_token_positions the wrong way?

In the gemma_pytorch repo (https://github.com/google/gemma_pytorch/tree/3294a89203f6227dc828b136564d3cb23bc6d115), I find that generation only feeds the previous token when producing each next token from the second step onward (https://github.com/google/gemma_pytorch/blob/3294a89203f6227dc828b136564d3cb23bc6d115/gemma/model.py#L659):

        for i in range(max_seq_len - min_prompt_len):
            # Forward pass over the current inputs; samples next_token_ids.
            next_token_ids, _ = self(
                input_token_ids=input_token_ids_tensor,
                input_positions=input_positions_tensor,
                kv_write_indices=None,
                kv_caches=kv_caches,
                mask=curr_mask_tensor,
                output_positions=output_positions_tensor,
                temperatures=temperatures_tensor,
                top_ps=top_ps_tensor,
                top_ks=top_ks_tensor,
            )

            # While still consuming the prompt, keep the prompt token instead
            # of the sampled one, then write the result back at output_index.
            curr_prompt_mask = prompt_mask_tensor.index_select(
                1, output_index).squeeze(dim=1)
            curr_token_ids = token_ids_tensor.index_select(
                1, output_index).squeeze(dim=1)
            output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                           next_token_ids).unsqueeze(dim=1)
            token_ids_tensor.index_copy_(1, output_index, output_token_ids)

            # The next step feeds ONLY the newly produced token; its position
            # is the single index output_index.
            input_token_ids_tensor = output_token_ids
            input_positions_tensor = output_index.unsqueeze(dim=-1)
            curr_mask_tensor = mask_tensor.index_select(2,
                                                        input_positions_tensor)
            output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
                device)
            output_index = output_index + 1

It assigns input_positions_tensor from output_index, which is a single position. Doesn't this lose all of the previous sentence context? From my point of view, this is equivalent to predicting "what is the next token, given one token at this position", with no previous context provided at all. I would have expected the implementation to append the current predicted token to the previous context, i.e. input_positions_tensor = torch.cat((input_positions_tensor, output_index)). Am I wrong? Could anyone explain what is happening here?
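To make my expectation concrete, here is a minimal sketch of the full-context decoding loop I had in mind, where every step re-feeds the whole sequence generated so far. The model(token_ids, positions) callable here is hypothetical, not gemma_pytorch's actual forward signature:

    import torch

    @torch.no_grad()
    def generate_full_context(model, prompt_ids, max_new_tokens):
        # prompt_ids: [batch, prompt_len] integer token ids.
        token_ids = prompt_ids
        for _ in range(max_new_tokens):
            seq_len = token_ids.shape[1]
            # Positions 0..seq_len-1 for the ENTIRE context, not just the
            # last token.
            positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
            logits = model(token_ids, positions)          # [batch, seq_len, vocab]
            next_token = logits[:, -1, :].argmax(dim=-1)  # greedy: take last position
            token_ids = torch.cat([token_ids, next_token.unsqueeze(1)], dim=1)
        return token_ids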

Please explain how a decoder-only causal LLM generates a full sentence, and whether the code in the gemma_pytorch repo is correct.
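For contrast, here is a minimal sketch of how a single-token decode step could still attend to the full context if the kv_caches argument stores past keys and values. The model and cache API below are hypothetical; I am only assuming the cache plays this role:

    import torch

    @torch.no_grad()
    def generate_with_kv_cache(model, prompt_ids, max_new_tokens):
        # Prefill: run the whole prompt once so the model can write keys and
        # values for every prompt position into its cache.
        positions = torch.arange(prompt_ids.shape[1],
                                 device=prompt_ids.device).unsqueeze(0)
        logits, cache = model(prompt_ids, positions, cache=None)
        token_ids = prompt_ids
        for _ in range(max_new_tokens):
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [batch, 1]
            token_ids = torch.cat([token_ids, next_token], dim=1)
            # Decode: feed ONLY the new token and its absolute position;
            # attention can still see the whole context because the past
            # keys/values live in `cache`.
            new_pos = torch.full((prompt_ids.shape[0], 1), token_ids.shape[1] - 1,
                                 dtype=torch.long, device=token_ids.device)
            logits, cache = model(next_token, new_pos, cache=cache)
        return token_ids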

Upvotes: 1

Views: 23

Answers (0)
