Ant

Reputation: 1153

Is there a way to see what's going wrong with a training session in PyTorch?

I'm training a triplet convolutional neural network in Jupyter.
When I execute the cell, it only shows the [*] busy marker and nothing else happens.

I'm not asking for help finding the bug in my code. I would just like to know whether there is a troubleshooting method that lets me see what is happening.
The problem is probably in my data loader, my data format, or my model, and I can track it down myself if PyTorch (or anyone) has a way to get a clue. There is no error; the cell just keeps running forever on something that is wrong.
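One standard-library option for a cell that runs forever without raising is faulthandler.dump_traceback_later, which writes the stack of every thread to a file once a timeout expires, so you can see which line the loop is sitting on (for example, waiting on the data loader versus inside the model call). A minimal sketch, not PyTorch-specific; the helper names, the 60-second timeout, and the log file name are all hypothetical choices for illustration.

```python
import faulthandler


def start_hang_watchdog(path="hang_traceback.log", timeout=60.0):
    """Dump every thread's stack to `path` every `timeout` seconds until cancelled.

    The path and timeout are illustrative defaults. The open file object is
    returned so the caller keeps it alive: faulthandler writes straight to its
    file descriptor from a background watchdog thread.
    """
    log = open(path, "w")
    faulthandler.dump_traceback_later(timeout, repeat=True, file=log)
    return log


def stop_hang_watchdog(log):
    """Cancel any pending dump and close the log file."""
    faulthandler.cancel_dump_traceback_later()
    log.close()
```

Usage: call log = start_hang_watchdog() just before the training loop and stop_hang_watchdog(log) after it. If you interrupt the kernel instead, run stop_hang_watchdog(log) in another cell, then open the log file and read the innermost frames of the main thread.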

I read about a function called set_trace() that you can put inside the loop to pause execution and inspect what is going on. But after adding it to the for loop, I get:

NameError: name 'set_trace' is not defined


import numpy as np

for epoch in range(1):  # outer loop over epochs; the inner loop owns batch_idx
    model.train()
    metrics = []
    losses = []      # running losses for the current logging window
    total_loss = 0
    for batch_idx, (data, target) in enumerate(triplet_train_loader):
        data = tuple(d.cuda() for d in data)  # move each triplet member to the GPU
        optimizer.zero_grad()
        outputs = model(*data)

        loss_outputs = loss_fn(*outputs)
        loss = loss_outputs[0] if isinstance(loss_outputs, (tuple, list)) else loss_outputs
        losses.append(loss.item())
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        set_trace()  # raises NameError: set_trace was never imported
        if batch_idx % log_interval == 0:
            message = 'Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx * len(data[0]), len(triplet_train_loader.dataset),
                100. * batch_idx / len(triplet_train_loader), np.mean(losses))
            for metric in metrics:
                message += '\t{}: {}'.format(metric.name(), metric.value())

            print(message)
            losses = []  # start a fresh window for the next log line
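Separately from a debugger, a quick way to tell whether the loop is stuck in the data loader or in the model is to log how long each batch takes to arrive. A minimal sketch; log_progress is a hypothetical helper and every is an arbitrary reporting interval. It should work on any iterable, including a torch.utils.data.DataLoader, because it only iterates over it.

```python
import time
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")


def log_progress(iterable: Iterable[T], every: int = 1,
                 log: Callable[[str], None] = print) -> Iterator[T]:
    """Yield items from `iterable`, logging how long each one took to produce.

    The timing excludes the time the consumer spends on each item (the model
    step), so it measures only the data pipeline. If "waiting for first
    item..." is the last line printed, the first batch never arrived.
    """
    log("waiting for first item...")
    start = time.perf_counter()
    last = start
    for i, item in enumerate(iterable):
        now = time.perf_counter()
        if i % every == 0:
            log(f"item {i}: fetched in {now - last:.3f}s (elapsed {now - start:.1f}s)")
        yield item
        last = time.perf_counter()  # restart the clock after the consumer returns
```

Usage in the loop above: for batch_idx, (data, target) in enumerate(log_progress(triplet_train_loader, every=10)). If batches keep arriving quickly but no loss is ever printed, the time is going into the model or loss; if they stop arriving, look at the dataset and loader.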

Expected and actual results:

Expected: the model trains. Actual: it does not train; the cell keeps running and nothing is printed.

Upvotes: 0

Views: 102

Answers (1)

artona

Reputation: 1282

NameError: name 'set_trace' is not defined

set_trace is not a built-in name; it lives in the pdb module, so it has to be imported first. You mean:

import pdb; pdb.set_trace()
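Since Python 3.7 the built-in breakpoint() does the same thing without an import (by default it calls pdb.set_trace()), and in Jupyter from IPython.core.debugger import set_trace gives IPython's own debugger. Stopping on every batch of a training loop gets tedious, so here is a sketch that only stops when it is likely to be informative; maybe_break and its conditions are hypothetical, chosen for illustration.

```python
import math


def maybe_break(batch_idx: int, loss_value: float, break_at: int = 0) -> bool:
    """Enter the debugger on batch `break_at` or when the loss is NaN/inf.

    Returns True if a breakpoint was triggered. The debugger opens inside this
    helper, so type `up` at the pdb prompt to reach the training loop's frame
    and inspect data, outputs, loss, etc.
    """
    if batch_idx == break_at or not math.isfinite(loss_value):
        breakpoint()  # built-in since Python 3.7; calls pdb.set_trace() by default
        return True
    return False
```

In the loop from the question, maybe_break(batch_idx, loss.item()) would take the place of set_trace(). If the debugger never opens at all, the loop never finished its first batch, which again points at the loader or the forward pass.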

Upvotes: 1
