Hellowhatsup

Reputation: 21

Mistral-7B: Does later text affect the logits of previous tokens?

For transformer decoders, I thought the internal causal mask prevents future tokens from affecting the logits of previous tokens. But consider the following code (I'm using Mistral-7B + torch.bfloat16):

### TEST
import torch

# Suppose there is some input that was tokenized and saved in "inp_rej";
# its shape is [1, 192]. "model1" is the Mistral-7B model on CUDA.

model1.eval()
new_inp = inp_rej[0, :173]  # first 173 tokens, as a 1-D tensor

with torch.no_grad():
    # Greedy generation from the 173-token prefix, keeping the scores of the first generated token.
    new_out1 = model1.generate(new_inp.unsqueeze(0), temperature=0, max_length=256,
                               return_dict_in_generate=True, output_scores=True)
    # Plain forward pass over the same 173-token prefix.
    temp_out1 = model1(new_inp.unsqueeze(0))
    # Forward pass over the full 192-token sequence.
    comp_out1 = model1(inp_rej)

# Max probability of the next-token distribution after the 173-token prefix, computed three ways:
a1 = torch.softmax(new_out1['scores'][0], dim=-1).max()                  # from generate()
a2 = torch.softmax(temp_out1.logits[0][-1], dim=-1).max()                # from the 173-token forward pass
a3 = torch.softmax(comp_out1.logits[0, len(new_inp) - 1], dim=-1).max()  # same position, 192-token forward pass

print(a1 - a2)  # tensor(0., device='cuda:0'), OK.
print(a1 - a3)  # tensor(0.0300, device='cuda:0'), why?

Why does a1 - a3 come out to 0.03 instead of 0?
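
(To rule out an artifact of only comparing the maximum probability, I could also compare the full next-token logit vectors from the snippet above. A sketch of that check, which I have not run:)

# Exact comparison of the full next-token logit vectors, not just the max probability.
print(torch.equal(new_out1['scores'][0][0], temp_out1.logits[0, -1]))                 # prefix vs. prefix
print((temp_out1.logits[0, -1] - comp_out1.logits[0, len(new_inp) - 1]).abs().max())  # prefix vs. full sequence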

Moreover, when I do:

abs(temp_out1.logits[0][0] - comp_out1.logits[0][0]).mean()

This outputs tensor(0.0122, device='cuda:0'), i.e. even the logits at the very first position differ between the two forward passes.
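
(A follow-up I have in mind, sketched from the same two forward passes but not yet run, would summarise the mismatch per position, to see whether it is confined to later positions or spread over the whole prefix:)

# Per-position max absolute logit difference between the 173-token and 192-token passes.
per_pos = (temp_out1.logits[0] - comp_out1.logits[0, :len(new_inp)]).abs().amax(dim=-1)
print(per_pos[:5])   # earliest positions
print(per_pos[-5:])  # last positions shared by both passes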

Interestingly, the size of the discrepancy seems to depend on the sequence length. Consider the code below:

for N in [30, 60, 90, 120, 150]:
    new_inp_rej = inp_rej.clone()[:, :N]  # first N tokens, shape [1, N]
    model1.eval()
    new_inp = new_inp_rej[0, :N - 20]     # prefix of length N - 20

    with torch.no_grad():
        new_out1 = model1.generate(new_inp.unsqueeze(0), temperature=0, max_length=256,
                                   return_dict_in_generate=True, output_scores=True)
        temp_out1 = model1(new_inp.unsqueeze(0))
        comp_out1 = model1(inp_rej)  # full 192-token sequence, as before

    a1 = torch.softmax(new_out1['scores'][0], dim=-1).max()
    a2 = torch.softmax(temp_out1.logits[0][-1], dim=-1).max()
    a3 = torch.softmax(comp_out1.logits[0, len(new_inp) - 1], dim=-1).max()

    diff = (a1 - a3).item()
    if diff != 0:
        print(N, ":", diff)

# RESULT:
# 60 : 3.6954879760742188e-06
# 90 : 0.0025225281715393066
# 120 : 0.00031453371047973633
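
For what it's worth, my suspicion (and the reason I flagged torch.bfloat16 above) is that the causal mask is fine and this is a numerical-precision effect: with a different sequence length, the matmul/attention kernels can use a different tiling or reduction order, so mathematically identical rows may round differently in bfloat16. Below is a minimal, model-free sketch of that kind of shape-dependent rounding; it is my own illustration, not part of the test above, and the printed difference may well be exactly 0 on CPU or with other kernels.

# The same rows of a bfloat16 matmul, computed once as part of a 192-row batch and once
# as a separate 173-row batch. Row i of x @ W mathematically depends only on row i of x,
# so any nonzero difference comes purely from kernel choice / rounding behaviour.
import torch

torch.manual_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
W = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)
x_full = torch.randn(192, 4096, dtype=torch.bfloat16, device=device)
x_prefix = x_full[:173]

y_full = x_full @ W      # one matmul over all 192 rows
y_prefix = x_prefix @ W  # separate matmul over the first 173 rows only
print((y_full[:173] - y_prefix).abs().max())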

Upvotes: 2

Views: 199

Answers (0)
