Reputation: 155
I am using ViViT in my model. Although I moved both the input and my whole model to CUDA, training fails with a device-mismatch error on the position-embedding line:
import torch
from torch import nn
from einops.layers.torch import Rearrange

# FSATransformerEncoder / FDATransformerEncoder are defined elsewhere in the ViViT implementation

class ViViTBackbone(nn.Module):
    """ Model-3 backbone of ViViT """

    def __init__(self, t, h, w, patch_t, patch_h, patch_w, num_classes, dim, depth, heads, mlp_dim, dim_head=3,
                 channels=3, mode='tubelet', emb_dropout=0., dropout=0., model=3):
        super().__init__()

        assert t % patch_t == 0 and h % patch_h == 0 and w % patch_w == 0, "Video dimensions should be divisible by " \
                                                                           "tubelet size "

        self.T = t
        self.H = h
        self.W = w
        self.channels = channels
        self.t = patch_t
        self.h = patch_h
        self.w = patch_w
        self.mode = mode

        self.nt = self.T // self.t
        self.nh = self.H // self.h
        self.nw = self.W // self.w

        tubelet_dim = self.t * self.h * self.w * channels

        self.to_tubelet_embedding = nn.Sequential(
            Rearrange('b c (t pt) (h ph) (w pw) -> b t (h w) (pt ph pw c)', pt=self.t, ph=self.h, pw=self.w),
            nn.Linear(tubelet_dim, dim)
        )

        # repeat same spatial position encoding temporally
        self.pos_embedding = nn.Parameter(torch.randn(1, 1, self.nh * self.nw, dim)).repeat(1, self.nt, 1, 1)

        self.dropout = nn.Dropout(emb_dropout)

        if model == 3:
            self.transformer = FSATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
                                                     self.nt, self.nh, self.nw, dropout)
        elif model == 4:
            assert heads % 2 == 0, "Number of heads should be even"
            self.transformer = FDATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
                                                     self.nt, self.nh, self.nw, dropout)

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """ x is a video: (b, C, T, H, W) """
        tokens = self.to_tubelet_embedding(x)

        tokens += self.pos_embedding  # the error is raised on this line
        tokens = self.dropout(tokens)

        x = self.transformer(tokens)
        return x
I create the ViViT backbone inside my model class as follows:
self.vivit_FSA_F_8 = ViViTBackbone(t=8, h=16, w=24, patch_t=1, patch_h=16, patch_w=24, num_classes=10, dim=128,
                                   depth=6, heads=10, mlp_dim=8, model=3)
How can I fix that?
Upvotes: 0
Views: 59
Reputation: 131
There are multiple ways to fix this. Instead of creating tensor attributes like:
self.T = t
register them as parameters:
self.T = nn.Parameter(t)
Then model.to(device) will push all registered parameters to the correct device as well.
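In the posted code, the .repeat() call is what breaks this: nn.Parameter(torch.randn(...)) is a registered parameter, but the result of .repeat(...) is a plain tensor, so pos_embedding is never moved by model.to(device). A minimal sketch of how the advice above could be applied to that line (assuming the positional embedding should stay learnable):

# sketch: keep the repeat inside nn.Parameter so the attribute stays registered
self.pos_embedding = nn.Parameter(
    torch.randn(1, 1, self.nh * self.nw, dim).repeat(1, self.nt, 1, 1)
)

# or, if it should not be trained, register it as a buffer instead;
# buffers are also moved by model.to(device)
self.register_buffer(
    'pos_embedding',
    torch.randn(1, 1, self.nh * self.nw, dim).repeat(1, self.nt, 1, 1)
)

Note that repeating inside nn.Parameter makes the temporal copies independently learnable; if they should stay tied, you could instead keep the parameter at shape (1, 1, nh*nw, dim) and let broadcasting handle the temporal dimension in forward.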
An alternative is to pass the device argument whenever you create a tensor:
some_tensor = torch.tensor(1.0, device=self.device)
or
some_tensor = torch.ones([3, 4], device=self.device)
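One caveat with the snippets above: nn.Module has no built-in self.device attribute, so self.device only works if you store it yourself. A common pattern (an assumption here, not something in the posted code) is to read the device off an existing parameter:

# derive the device from any registered parameter of the module
device = next(self.parameters()).device
some_tensor = torch.ones([3, 4], device=device)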
Upvotes: 2