I'm learning how to integrate Flash Attention into my model to speed up training, and I'm testing the attention function on a small example before wiring it in. However, I've run into an issue where Flash Attention produces different outputs for tokens that have identical embeddings. I'm not sure whether I'm making a basic mistake or whether something else is at play.
Here's the code snippet I'm using:
import torch
from flash_attn.modules.mha import FlashSelfAttention

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fa_attn = FlashSelfAttention(deterministic=True)
fa_attn.eval()

# batch_size, seq_len, heads, head_dim = 1, 4, 1, 4
x = torch.tensor([[0.1, 0.1, 0.1, 0.1],
                  [0.1, 0.1, 0.1, 0.1],
                  [0.1, 0.1, 0.1, 0.1],
                  [0.1, 0.1, 0.1, 0.1]])
q = x.unsqueeze(0).unsqueeze(2)                         # (1, 4, 1, 4)
k = q.clone()
v = q.clone()
qkv = torch.stack([q, k, v], dim=2).half().to(device)  # (1, 4, 3, 1, 4) = (B, S, 3, H, D)
output = fa_attn(qkv)
print(output)
result:
tensor([[[[0.1000, 0.1000, 0.1000, 0.1000]],
         [[0.0757, 0.0757, 0.0757, 0.0757]],
         [[0.1000, 0.1000, 0.1000, 0.1000]],
         [[0.0757, 0.0757, 0.0757, 0.0757]]]], device='cuda:0',
       dtype=torch.float16)
Here is another run, this time with seq_len = 3:
# batch_size, seq_len, heads, head_dim = 1, 3, 1, 4
x = torch.tensor([[0.1, 0.1, 0.1, 0.1],
                  [0.1, 0.1, 0.1, 0.1],
                  [0.1, 0.1, 0.1, 0.1]])
q = x.unsqueeze(0).unsqueeze(2)                         # (1, 3, 1, 4)
k = q.clone()
v = q.clone()
qkv = torch.stack([q, k, v], dim=2).half().to(device)  # (1, 3, 3, 1, 4)
output = fa_attn(qkv)
print(output)
result:
tensor([[[[ 0.1000, 0.1000, 0.1000, 0.1000]],
         [[-0.5483, 0.5166, -0.5483, 0.5166]],
         [[ 0.1000, 0.1000, 0.1000, 0.1000]]]], device='cuda:0',
       dtype=torch.float16)
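
For comparison, here is the sanity check I had in mind (my own sketch, assuming torch >= 2.0 for F.scaled_dot_product_attention): since every query/key/value row is identical, the softmax weights should be uniform, so I expect every output row to come out as 0.1 across the board.

import torch
import torch.nn.functional as F

# Same identical-row input, laid out as (batch, heads, seq_len, head_dim).
x = torch.full((1, 1, 4, 4), 0.1)
ref = F.scaled_dot_product_attention(x, x, x)
print(ref)  # every row should come out as ~[0.1, 0.1, 0.1, 0.1]

That is the behaviour I expected from FlashSelfAttention as well.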
Thanks very much.