bindas

Reputation: 53

Wandb training kills the kernel in JupyterLab

In my Jupyter notebook I can train my model with batch_size=8, but when I use wandb, the process is always killed after 9 iterations and the kernel restarts. What's even weirder is that the same code worked on Colab, but on my own GPU (an RTX 3080) I can never finish training.

Does anyone have any idea how to overcome this issue?

Edit: I noticed that the kernel dies every time wandb tries to log the gradients. Can this be solved?
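
One workaround I'm considering is telling wandb.watch to log less (or no) gradient data, using its log and log_freq arguments; a rough sketch of what I mean (I haven't verified that this avoids the crash):

# Possible mitigation (untested): skip the gradient/parameter histograms
# that seem to trigger the crash.
wandb.watch(model, criterion, log=None)
# ...or keep them but log far less often:
# wandb.watch(model, criterion, log="all", log_freq=1000)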

Code with wandb:

import torch
import wandb

# device, architecture, dataset, and the predefined dataloaders
# (train_dl_predef, test_dl_predef) are defined elsewhere in the notebook.

def train_batch(images, labels, model, optimizer, criterion):
    images, labels = images.to(device), labels.to(device)
    
    # Forward pass ➡
    outputs = model(images)
    loss = criterion(outputs, labels)
    
    # Backward pass ⬅
    optimizer.zero_grad()
    loss.backward()

    # Step with optimizer
    optimizer.step()
    
    size = images.size(0)
    del images, labels
    return loss, size

from loss import YoloLoss

# train the model
def train(model, train_dl, criterion, optimizer, config, is_one_batch):
    # Tell wandb to watch what the model gets up to: gradients, weights, and more!
    wandb.watch(model, criterion, log="all", log_freq=10)

    example_ct = 0  # number of examples seen
    batch_ct = 0
    
    # enumerate epochs
    for epoch in range(config.epochs):
        running_loss = 0.0
        
        if not is_one_batch:
            for i, (inputs, _, targets) in enumerate(train_dl):
                loss, batch_size = train_batch(inputs, targets, model, optimizer, criterion)
                running_loss += loss.item() * batch_size
        else:
            # for one batch only
            loss, batch_size = train_batch(train_dl[0], train_dl[2], model, optimizer, criterion)
            running_loss += loss.item() * batch_size
            
        epoch_loss = running_loss / len(train_dl)
#         loss_values.append(epoch_loss)
        wandb.log({"epoch": epoch, "avg_batch_loss": epoch_loss})
#         wandb.log({"epoch": epoch, "loss": loss}, step=example_ct)
        print("Average epoch loss {}".format(epoch_loss))
def make(config, is_one_batch, data_predefined=True):
    optimizers = {
        "Adam": torch.optim.Adam,
        "SGD": torch.optim.SGD,
    }
    
    if data_predefined:
        train_dl, test_dl = train_dl_predef, test_dl_predef
    else:
        train_dl, test_dl = dataset.prepare_data()
        
    if is_one_batch:
        train_dl = next(iter(train_dl))
        test_dl = train_dl
    
    # Make the model
    model = architecture.darknet(config.batch_norm)
    model.to(device)

    # Make the loss and optimizer
    criterion = YoloLoss()
    optimizer_kwargs = {"lr": config.learning_rate}
    if config.optimizer == "SGD":
        # torch.optim.Adam does not accept a momentum argument
        optimizer_kwargs["momentum"] = config.momentum
    optimizer = optimizers[config.optimizer](model.parameters(), **optimizer_kwargs)
    
    return model, train_dl, test_dl, criterion, optimizer
        
def model_pipeline(hyp, is_one_batch=False, device=device):
    with wandb.init(project="YOLO-recreated", entity="bindas1", config=hyp):
        config = wandb.config
        
        # make the model, data, and optimization problem
        model, train_dl, test_dl, criterion, optimizer = make(config, is_one_batch)
        
        # and use them to train the model
        train(model, train_dl, criterion, optimizer, config, is_one_batch)
        
    return model

Code without wandb:

from torch.optim import SGD
from tqdm import tqdm

from loss import YoloLoss

# LEARNING_RATE, MOMENTUM, EPOCHS, device, and plot_loss are defined elsewhere.
def train_model(train_dl, model, is_one_batch=False):
    # define the optimization
    criterion = YoloLoss()
    optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
    
    # for loss plotting
    loss_values = []
    
    # enumerate epochs
    for epoch in tqdm(range(EPOCHS)):
        if epoch % 10 == 0:
            print(epoch)
        running_loss = 0.0
        
        if not is_one_batch:
            # enumerate mini-batches
            for i, (inputs, _, targets) in enumerate(train_dl):
                inputs = inputs.to(device)
                targets = targets.to(device)
                # clear the gradients
                optimizer.zero_grad()
                # compute the model output
                yhat = model(inputs)
                # calculate loss
                loss = criterion(yhat, targets)
                # credit assignment
                loss.backward()
#                 print(loss)
                running_loss += loss.item() * inputs.size(0)
                # update model weights
                optimizer.step()
        else:
            # for one batch only
            with torch.autograd.detect_anomaly():
                inputs, targets = train_dl[0].to(device), train_dl[2].to(device)
                optimizer.zero_grad()
                # compute the model output
                yhat = model(inputs)
                # calculate loss
                loss = criterion(yhat, targets)
                # credit assignment
                loss.backward()
                print(loss)
                running_loss += loss.item() * inputs.size(0)
                # update model weights
                optimizer.step()
        loss_values.append(running_loss / len(train_dl))
    
    plot_loss(loss_values)

model = architecture.darknet()
model.to(device)
optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)  # unused: train_model builds its own optimizer
train_dl_main, test_dl_main = train_dl_predef, test_dl_predef
one_batch = next(iter(train_dl_main))
train_model(one_batch, model, is_one_batch=True)

Upvotes: 0

Views: 672

Answers (1)

morganmcg

Reputation: 573

Hmm, strange. So in your edit you're saying that it works OK if you remove wandb.watch?

To double-check: have you tried the original code on the latest version of wandb (0.12.7)?
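
If upgrading alone doesn't help, one more thing worth trying is keeping wandb but dialing back what watch records. log="all" registers a hook on every parameter to collect gradient histograms, which is where your kernel seems to die, so a rough sketch of a test would be:

# after upgrading: pip install --upgrade wandb
wandb.watch(model, criterion, log="parameters", log_freq=10)  # weights only, no gradient hooks
# or disable histogram logging entirely:
# wandb.watch(model, criterion, log=None)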

Upvotes: 1
