Reputation: 9
I am trying to extract the attention maps from a PyTorch implementation of the Vision Transformer (ViT), but I am having trouble understanding how to do this. I understand that doing it from within the EncoderBlock is probably not the best approach.
Inside the EncoderBlock I managed to extract the attention weights per layer (using the default vit_b_16, i.e. 12 heads, patch size of 16x16). What I am struggling with is how to overlay these weights on the original image. Ideally, I would like an output like this (https://github.com/huggingface/pytorch-image-models/discussions/1232):
from functools import partial
from typing import Callable

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision.models.vision_transformer import MLPBlock


class EncoderBlock(nn.Module):
    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads
        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        # need_weights=True makes MultiheadAttention also return the attention weights
        # (averaged over the heads by default)
        x, weights = self.self_attention(x, x, x, need_weights=True)
        x = self.dropout(x)
        x = x + input
        y = self.ln_2(x)
        y = self.mlp(y)
        result = x + y
        # Visualise this layer's attention weights
        plt.imshow(weights.squeeze().detach().cpu().numpy())
        plt.show()
        return result
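For reference, this is roughly how I run the model once the block is modified; the shape reasoning in the comments is my own assumption (224x224 input, patch size 16, so 196 patch tokens plus the class token), not something I have verified against the documentation:

import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT).eval()
img = torch.rand(1, 3, 224, 224)   # stand-in for a real, preprocessed image
with torch.no_grad():
    _ = model(img)                 # with the modified EncoderBlock this plots one map per layer
# My understanding: each layer's `weights` tensor is (1, 197, 197), i.e. the class token plus
# 196 patch tokens (a 14 x 14 grid), averaged over the 12 heads since average_attn_weights
# defaults to True.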
If I plot the weights as an image for each layer, I get an image similar to the first one linked above. However, I cannot work out how to get to the final overlaid image. I would appreciate any assistance.
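To make the question concrete, this is the kind of overlay step I have been experimenting with; the choice of the class-token row, the 14x14 reshape, and the bilinear upsampling are all guesses on my part about how the weights map back onto the image:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def overlay_attention(img, weights):
    # img: (3, 224, 224) tensor in [0, 1]; weights: (1, 197, 197) from one encoder layer
    cls_attn = weights[0, 0, 1:]                 # class token's attention to the 196 patch tokens
    cls_attn = cls_attn.reshape(1, 1, 14, 14)    # back to the 14 x 14 patch grid
    cls_attn = F.interpolate(cls_attn, size=(224, 224), mode="bilinear", align_corners=False)
    cls_attn = cls_attn.squeeze().detach().cpu().numpy()
    plt.imshow(img.permute(1, 2, 0).cpu().numpy())
    plt.imshow(cls_attn, cmap="jet", alpha=0.5)  # heatmap blended over the image
    plt.axis("off")
    plt.show()

Is something along these lines the right way to go from the per-layer weights to an overlay like the one in the linked discussion, or do the layers/heads need to be combined differently first?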
Upvotes: 0
Views: 748