ayush sharma
ayush sharma

Reputation: 1

transformer model predicting the same token during infrence but performing well during training

my transformer model is not working right.

Training loop :

for epoch in range(40):
 data_loader = tqdm(data_loader, desc=f"Epoch {epoch + 1}/{20}", unit="batch")
for batch_idx, (en_batch, hi_batch) in enumerate(data_loader):
en_batch = en_batch.to('cuda').to(torch.long)
hi_batch = hi_batch.to('cuda').to(torch.long)
y_pred = model(en_batch, hi_batch)
loss = loss_fn(y_pred[:, 0:127, :].transpose(2,1), hi_batch[:, 1:128]).mean()
history.append(loss.item())
if batch_idx % 400 == 0:
clear_output(wait=False)
torch.save(history, 'history2.pth')
torch.save(losses, 'valLoss2.pth')
torch.save(model.state_dict(), 'model_weightsBPE2.pth')
model.eval()
val_loss = 0
with torch.no_grad():
for id , (en_batch, hi_batch) in enumerate(val_loader, 1):
en_batch, hi_batch = en_batch.to('cuda'), hi_batch.to('cuda')
y_pred = model(en_batch.to(torch.long), hi_batch[:, :-1].to(torch.long))
val_loss += loss_fn(y_pred[:, 0:127, :].transpose(2,1), hi_batch[:, 1:128].to(torch.long)).mean()
if id % 5 == 0:
break
val_loss /= 5
model.train()
print(f"Validation Loss: {val_loss}")
losses.append(val_loss.item())
if batch_idx % 100 == 0:
print("-"20, batch_idx, "-", epoch, "-", loss.item(), "-"20)
print("en : ", id_to_token(en_batch, "en"))
print("hi : ", id_to_token(hi_batch, "hi"))
print("out: ", id_to_token_M(y_pred, "hi"))
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()

Training output :

out: संघ के प्रदेशाध्यक्ष बृज मो गाुप्ता ने महारे्ल इेंदिर क के शबिदिन करे विद्रा।हन्थिया। कोप माया।

Inference loop :

def inference_loop( input_seq, max_output_length=128):
 input_seq = input_seq.unsqueeze(0) # Add batch dimension
current_token = torch.tensor([[1]], device=input_seq.device)
output_seq = []
for * in range(max*output_length):
predictions = model(input_seq, current_token)
next_token = predictions[:, -1, :].argmax(dim=-1)
current_token = torch.cat([current_token, next_token.unsqueeze(0)], dim=1)
if next_token.item() == 2:
break
output_seq.append(next_token.item())
return output_seq

Inference Output :

' क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क क'


how can i fix this issue, both training and validation loss goes down but during inference it starts to predict the same token again and again.

Upvotes: 0

Views: 82

Answers (0)

Related Questions