Justin

Reputation: 1

gradient accumulation stopping at 50%

My training stops at 50%.

The original batch_size = 16, but I wanted to use gradient accumulation with ACCUMULATION = 2 so that I get an effect similar to batch_size = 32.

The original training took an hour, so I expected the run with gradient accumulation to take about two hours.

But the training ends at 50% and still lasts only an hour, even with gradient accumulation.

I don't know why it's stopping. Below is my training code:

def train_runner(model, train_dataset, valid_dataset, batch_size, num_train_epochs, learning_rate):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model.to(device)
    model.train()
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size)
    valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batch_size)

    lowest_total_valid_loss = 9999.
    step = 0
    global_total_step = len(train_dataloader) * num_train_epochs
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0)
    print("TRAIN START")
    with tqdm(total=global_total_step, unit='step') as t:
        total = 0
        total_loss = 0
        for epoch in range(num_train_epochs):
            for iteration, batch in enumerate(train_dataloader):
                #optimizer.zero_grad()
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                start_positions = batch['start_positions'].to(device)
                end_positions = batch['end_positions'].to(device)
                outputs = model(input_ids,
                                attention_mask=attention_mask,
                                start_positions=start_positions,
                                end_positions=end_positions)
                loss = outputs.loss
                (loss / ACCUMULATION).backward()

                step += 1
                if step % ACCUMULATION:
                    continue

                clip_grad_norm_(model.parameters(), max_norm=1.)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                batch_loss = loss.item() * len(input_ids)
                total += len(input_ids)
                total_loss += batch_loss / ACCUMULATION
                global_total_step += 1
                t.set_postfix(loss="{:.6f}".format(total_loss / total), batch_loss="{:.6f}".format(batch_loss))
                t.update(1)

                del input_ids
                del attention_mask
                del start_positions
                del end_positions
                del outputs
                del loss

                ## validation ##
                if iteration != 0 and iteration % int(len(train_dataloader) / 10) == 0:
                    total_valid_loss = 0
                    for batch_val in valid_dataloader:
                        model.eval()
                        optimizer.zero_grad()

                        input_ids = batch_val['input_ids'].to(device)
                        attention_mask = batch_val['attention_mask'].to(device)
                        start_positions = batch_val['start_positions'].to(device)
                        end_positions = batch_val['end_positions'].to(device)

                        with torch.no_grad():
                            outputs = model(input_ids,
                                            attention_mask=attention_mask,
                                            start_positions=start_positions,
                                            end_positions=end_positions)
                            loss = outputs.loss
                            total_valid_loss += loss.item()

                    if total_valid_loss < lowest_total_valid_loss:
                        print(f"lowest_total_valid_loss: {total_valid_loss} epoch : {epoch} iteration : {iteration}")
                        torch.save(model.state_dict(), './output_model_best')
                        lowest_total_valid_loss = total_valid_loss
                ## validation ##

    #model.save_pretrained("./klue_output_model")
    print("TRAIN END")

Upvotes: 0

Views: 133

Answers (1)

Daraan

Reputation: 3780

for iteration, batch in enumerate(train_dataloader):
    if step % ACCUMULATION:
        t.update(1)  # add one update here as well
        continue
    ...
    t.update(1)

With ACCUMULATION = 2 you hit the continue on every other step, so half of the time the loop never reaches t.update(1), while the total passed to tqdm at initialization still counts every batch. In other words, either the counter is updated too rarely or its total is set twice too high; either way the bar can't go higher than 50%. The training loop itself still iterates over every batch, so only the progress bar stops early.
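If you would rather have one progress-bar step per optimizer update instead of per batch, you can divide the total by ACCUMULATION when creating the bar. Below is a minimal, self-contained sketch of that accounting; the numbers are made up, the model/optimizer work is replaced by comments, and only ACCUMULATION and the loop structure are taken from the question:

from tqdm import tqdm

ACCUMULATION = 2
num_train_epochs = 1
num_batches = 10                     # stand-in for len(train_dataloader)

step = 0
# One bar step per optimizer update, so the total is the batch count divided by ACCUMULATION.
with tqdm(total=(num_batches // ACCUMULATION) * num_train_epochs, unit='update') as t:
    for epoch in range(num_train_epochs):
        for iteration in range(num_batches):
            # forward pass and (loss / ACCUMULATION).backward() would go here
            step += 1
            if step % ACCUMULATION:
                continue             # still accumulating: no optimizer.step(), no bar update
            # optimizer.step() and optimizer.zero_grad() would go here
            t.update(1)              # exactly one update per optimizer step, so the bar reaches 100%

Either approach works; what matters is that the number of t.update(1) calls matches the total you pass to tqdm.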

Upvotes: 0
