chungking youjun

Reputation: 1

Flash Attention gives different results for tokens with identical embeddings?

I'm learning how to integrate Flash Attention into my model to speed up training, and I'm testing the attention function on its own to figure out the best way to wire it in. However, I've run into an issue: Flash Attention produces different outputs for tokens that have identical embeddings. Since every token has the same embedding, every attention score should be equal, each softmax row should be uniform, and every position's output should equal the (identical) value vector, so I expected the same output at every position. I'm not sure if I'm making a basic mistake or if something else is going on.

Here's the code snippet I'm using:

import torch
from flash_attn.modules.mha import FlashSelfAttention

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fa_attn = FlashSelfAttention(deterministic=True)
fa_attn.eval()  # eval mode, so attention dropout is not applied

# batch_size, seq_len, heads, head_dim = 1, 4, 1, 4; all four token embeddings are identical
x = torch.tensor([[0.1, 0.1, 0.1, 0.1],
                  [0.1, 0.1, 0.1, 0.1],
                  [0.1, 0.1, 0.1, 0.1],
                  [0.1, 0.1, 0.1, 0.1]])
q = x.unsqueeze(0).unsqueeze(2)                         # (batch, seq_len, heads, head_dim) = (1, 4, 1, 4)
k = q.clone()
v = q.clone()
qkv = torch.stack([q, k, v], dim=2).half().to(device)  # (batch, seq_len, 3, heads, head_dim) = (1, 4, 3, 1, 4)
output = fa_attn(qkv)
print(output)

result:

tensor([[[[0.1000, 0.1000, 0.1000, 0.1000]],
         [[0.0757, 0.0757, 0.0757, 0.0757]],
         [[0.1000, 0.1000, 0.1000, 0.1000]],
         [[0.0757, 0.0757, 0.0757, 0.0757]]]], device='cuda:0', dtype=torch.float16)

Here's another run, this time with seq_len = 3:

# Same setup, but with seq_len = 3
x = torch.tensor([[0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1]])
q = x.unsqueeze(0).unsqueeze(2)                         # (1, 3, 1, 4)
k = q.clone()
v = q.clone()
qkv = torch.stack([q, k, v], dim=2).half().to(device)  # (1, 3, 3, 1, 4)
output = fa_attn(qkv)
print(output)

result:

tensor([[[[ 0.1000,  0.1000,  0.1000,  0.1000]],

         [[-0.5483,  0.5166, -0.5483,  0.5166]],

         [[ 0.1000,  0.1000,  0.1000,  0.1000]]]], device='cuda:0',
       dtype=torch.float16)
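
For comparison, below is a plain-PyTorch reference check that I would expect Flash Attention to match (a minimal sketch, assuming PyTorch >= 2.0 for torch.nn.functional.scaled_dot_product_attention; it takes tensors laid out as (batch, heads, seq_len, head_dim), hence the transposes). Since all the scores are equal, each softmax row is uniform and the output should just reproduce the value vector, i.e. 0.1000 everywhere:

import torch.nn.functional as F

# Reference attention in float32 on the same 3-token q, k, v defined above.
# scaled_dot_product_attention expects (batch, heads, seq_len, head_dim), while the
# tensors above are (batch, seq_len, heads, head_dim), so swap dims 1 and 2.
q_ref = q.transpose(1, 2).float()
k_ref = k.transpose(1, 2).float()
v_ref = v.transpose(1, 2).float()
ref = F.scaled_dot_product_attention(q_ref, k_ref, v_ref)
print(ref)  # expected: [0.1000, 0.1000, 0.1000, 0.1000] in every row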

Thanks very much.

Upvotes: 0

Views: 90

Answers (0)
