Peter

Reputation: 9

PyTorch Vision Transformer - How to Visualise Attention Layers

I am trying to extract the attention maps from a PyTorch implementation of the Vision Transformer (ViT). However, I am having trouble understanding how to do this. I understand that doing it from within the EncoderBlock is probably not the best way.

Inside the EncoderBlock I managed to extract the attention weights per layer (using the default vit_b_16, i.e. 12 heads and a 16x16 patch size). What I am struggling with is how to overlay this with the original image. Ideally, I would like an output like this (https://github.com/huggingface/pytorch-image-models/discussions/1232):

(Images: Stage 1, Stage 2 and Stage 3 attention overlays from the linked discussion)
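In case it is useful, I think the same weights can be captured without editing torchvision's source, by registering forward hooks on the encoder blocks and re-running each block's attention with need_weights=True inside the hook. Here is my rough sketch (only tested with a dummy input; attn_maps and make_hook are my own names, and I am assuming torchvision >= 0.13 for the weights enum):

import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT).eval()
attn_maps = []  # one (batch, num_heads, seq_len, seq_len) tensor per layer

def make_hook(block):
    def hook(module, inputs, output):
        # Re-run this block's attention on the same pre-norm input,
        # this time asking for the per-head weights.
        x = block.ln_1(inputs[0])
        _, weights = block.self_attention(
            x, x, x, need_weights=True, average_attn_weights=False
        )
        attn_maps.append(weights.detach())
    return hook

for block in model.encoder.layers:  # 12 encoder blocks in vit_b_16
    block.register_forward_hook(make_hook(block))

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))  # dummy image for illustration

print(len(attn_maps), attn_maps[0].shape)  # 12, torch.Size([1, 12, 197, 197])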

from functools import partial
from typing import Callable

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision.models.vision_transformer import MLPBlock


class EncoderBlock(nn.Module):
    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=attention_dropout, batch_first=True
        )
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        # need_weights=True returns the attention weights, averaged over heads by default
        x, weights = self.self_attention(x, x, x, need_weights=True)
        x = self.dropout(x)
        x = x + input
        y = self.ln_2(x)
        y = self.mlp(y)

        # Debug: show this layer's (seq_len x seq_len) attention map
        plt.imshow(weights.squeeze().detach().cpu().numpy())
        plt.show()
        return x + y

If I output the weights as an image for each layer, I get an image similar to the first one above. However, I cannot work out how to get to the final overlaid image. I would appreciate any assistance.

(Image: my weights output for one layer)
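Based on the linked discussion, I believe the overlay recipe is roughly: average the heads, take the CLS token's attention to the 196 patch tokens, reshape it to the 14x14 patch grid, upsample to 224x224, and alpha-blend it over the input. Here is my rough sketch of that (assuming attn_maps from the hook sketch above, and img as a (3, 224, 224) float tensor in [0, 1]; both names are placeholders):

import torch.nn.functional as F
import matplotlib.pyplot as plt

weights = attn_maps[-1].mean(dim=1)  # average heads -> (1, 197, 197)
cls_attn = weights[0, 0, 1:]         # CLS token -> 196 patch tokens
grid = cls_attn.reshape(14, 14)      # 224 / 16 = 14 patches per side
heatmap = F.interpolate(
    grid[None, None], size=(224, 224), mode="bilinear", align_corners=False
)[0, 0]

plt.imshow(img.permute(1, 2, 0))     # image as HWC
plt.imshow(heatmap, cmap="jet", alpha=0.5)
plt.axis("off")
plt.show()

This only uses the last layer's head-averaged attention, so I would expect a coarser map than the Stage 3 image above. Is this the right general approach?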

Upvotes: 0

Views: 748

Answers (0)
