Reputation: 25
I ran into a programming issue when calling a class; it seems I can't call it correctly. Can you please point out the issue? Thank you!
import torch
import torch.nn as nn

class NTXentLoss(nn.Module):
    def __init__(self, temp=0.5):
        super(NTXentLoss, self).__init__()
        self.temp = temp

    def forward(self, zi, zj):
        batch_size = zi.shape[0]
        # Stack the two sets of projections into one (2N, d) tensor.
        z_proj = torch.cat((zi, zj), dim=0)
        cos_sim = torch.nn.CosineSimilarity(dim=-1)
        # Pairwise cosine-similarity matrix of shape (2N, 2N).
        sim_mat = cos_sim(z_proj.unsqueeze(1), z_proj.unsqueeze(0))
        sim_mat_scaled = torch.exp(sim_mat / self.temp)
        # Positive pairs sit on the off-diagonals at offsets +/- batch_size.
        r_diag = torch.diag(sim_mat_scaled, batch_size)
        l_diag = torch.diag(sim_mat_scaled, -batch_size)
        pos = torch.cat([r_diag, l_diag])
        # Self-similarity terms to subtract from the denominator (assumes CUDA inputs).
        diag_mat = torch.exp(torch.ones(batch_size * 2) / self.temp).cuda()
        logit = -torch.log(pos / (sim_mat_scaled.sum(1) - diag_mat))
        loss = logit.mean()
        return loss
sent_A = l2norm(recov_A, dim=1)
sent_emb_A = l2norm(imgs_A, dim=1)
sent_B = l2norm(recov_B, dim=1)
sent_emb_B = l2norm(imgs_B, dim=1)
G_cons = NTXentLoss(sent_A,sent_emb_A) + NTXentLoss(sent_B,sent_emb_B)
What's wrong with this? I just passed two positional arguments. Or should it be:
G_cons = NTXentLoss.forward(sent_A,sent_emb_A) + NTXentLoss.forward(sent_B,sent_emb_B)
Upvotes: 0
Views: 896
Reputation: 383
NTXentLoss(sent_A, sent_emb_A) does not call forward; it tries to construct a new object and passes the two tensors to __init__, which only accepts temp. You need to first instantiate an NTXentLoss object before you can call it. For instance:
ntx = NTXentLoss()
G_cons = ntx(sent_A,sent_emb_A) + ntx(sent_B,sent_emb_B)
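
For completeness, here is a minimal usage sketch assuming the projections live on a GPU (the original forward moves one internal tensor to CUDA, so CPU-only inputs would hit a device mismatch); the shapes below are made up for illustration:

import torch

# Hypothetical batch of 8 projection vectors of size 128 (illustrative shapes only).
zi = torch.randn(8, 128).cuda()
zj = torch.randn(8, 128).cuda()

ntx = NTXentLoss(temp=0.5)   # construct the module once
loss = ntx(zi, zj)           # calling the instance dispatches to forward()
print(loss.item())

Calling the instance (ntx(...)) rather than NTXentLoss.forward(...) also lets nn.Module run its hooks, which is why the instantiate-then-call pattern is the idiomatic one in PyTorch.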
Upvotes: 2