AKSHAY JAIN

Reputation: 23

Positional encoding for Vision transformer

Why is the positional encoding of size (1, patch, emb)? I would expect it to be (batch_size, patch, emb) in general. Even in the PyTorch GitHub code https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py they define it as:
self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT

Can anyone help me? Should I use this as the positional encoding in my code:

self.pos_embedding = nn.Parameter(torch.empty(batch_size, seq_length, hidden_dim).normal_(std=0.02))

Is it correct?

Upvotes: 0

Views: 320

Answers (1)

Egor Prokopov

Reputation: 16

Because you don't know the batch_size when initializing self.pos_embedding, you should initialize this tensor as:

self.pos_embedding = nn.Parameter(
    torch.empty(1, num_patches + 1, hidden_dim).normal_(std=0.02)
) 
# (don't forget about the cls token)

PyTorch will take care of broadcasting the tensors in the forward pass:

x = x + self.pos_embedding
# (batch_size, num_patches + 1, embedding_dim) + (1, num_patches + 1, embedding_dim) is ok

Broadcasting won't work for the cls token, though, because it is concatenated to the sequence rather than added, so each sample needs its own copy. You should expand that tensor in forward:

cls_token = self.cls_token.expand(
    batch_size, -1, -1
)  # (1, 1, hidden_dim) -> (batch_size, 1, hidden_dim), view only, no data copied

Upvotes: 0
