jdhao

How to calculate the loss over a number of images, then backpropagate the average loss and update the network weights

I am doing a task where the batch size is 1, i.e., each batch contains only 1 image. So I have to do manual batching: when the number of accumulated losses reaches a certain number, I average the loss and then do backpropagation. My original code is:

real_batchsize = 200

for epoch in range(1, 5):
    net.train()

    total_loss = Variable(torch.zeros(1).cuda(), requires_grad=True)

    iter_count = 0
    for batch_idx, (input, target) in enumerate(train_loader):

        input, target = Variable(input.cuda()), Variable(target.cuda())
        output = net(input)

        loss = F.nll_loss(output, target)

        total_loss = total_loss + loss

        if batch_idx % real_batchsize == 0:
            iter_count += 1

            ave_loss = total_loss/real_batchsize
            ave_loss.backward()
            optimizer.step()

            if iter_count % 10 == 0:
                print("Epoch:{}, iteration:{}, loss:{}".format(epoch,
                                                           iter_count,
                                                           ave_loss.data[0]))
            total_loss.data.zero_() 
            optimizer.zero_grad()

This code gives the following error message:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I have tried the following approaches:

First way (failed)

I read some posts about this error message but cannot fully understand it. Changing ave_loss.backward() to ave_loss.backward(retain_graph=True) prevents the error message, but the loss doesn't improve and soon becomes nan.

Second way (failed)

I also tried changing the accumulation to total_loss = total_loss + loss.data[0]. This also prevents the error message, but the loss is always the same, so something must be wrong.

Third way (success)

Following the instructions in this post, I divide each image's loss by real_batchsize and backprop it immediately. When the number of input images reaches real_batchsize, I do one parameter update using optimizer.step(). The loss slowly decreases as training goes on, but training is really slow, because we backprop for each image.
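The third way is what is usually called gradient accumulation: because gradients add up in .grad until zero_grad() is called, backpropagating loss / real_batchsize for every image yields the same gradient as one backward pass on the averaged loss. A minimal sketch checking this, assuming PyTorch 0.4 or later (plain tensors instead of Variable); the tiny model, the random images and the group size of 5 are hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy model and a group of 5 single-image batches.
net = nn.Linear(16, 3)
images = torch.randn(5, 1, 16)            # 5 "batches" of one image each
targets = torch.randint(0, 3, (5, 1))     # one class label per image
real_batchsize = len(images)

def log_probs(x):
    return F.log_softmax(net(x), dim=1)

# Third way: backprop each image's loss / real_batchsize right away;
# the gradients are summed into net.weight.grad.
net.zero_grad()
for x, y in zip(images, targets):
    (F.nll_loss(log_probs(x), y) / real_batchsize).backward()
grad_accumulated = net.weight.grad.clone()

# Reference: a single backward pass on the average loss of the group.
net.zero_grad()
ave_loss = sum(F.nll_loss(log_probs(x), y)
               for x, y in zip(images, targets)) / real_batchsize
ave_loss.backward()
grad_averaged = net.weight.grad.clone()

print(torch.allclose(grad_accumulated, grad_averaged, atol=1e-6))
```

The two gradients agree up to floating-point rounding, so the third way is mathematically sound; the question is only about how to avoid one backward call per image.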

My question

What does the error message mean in my case? Also, why don't the first and second ways work? How do I write the code correctly so that the gradient is backpropagated once every real_batchsize images and the parameters are updated once, so that training is faster? I know my code is nearly correct, but I just do not know how to change it.

Answers (1)

cleros

The problem you encounter here is related to how PyTorch accumulates gradients over different passes (see here for another post on a similar question). So let's have a look at what happens when you have code of the following form:

loss_total = Variable(torch.zeros(1).cuda(), requires_grad=True)
for l in (loss_func(x1,y1), loss_func(x2, y2), loss_func(x3, y3), loss_func(x4, y4)):
    loss_total = loss_total + l
    loss_total.backward()

Here, we do a backward pass when loss_total has the following values over the different iterations:

loss_total = loss_func(x1, y1)
loss_total = loss_func(x1, y1) + loss_func(x2, y2)
loss_total = loss_func(x1, y1) + loss_func(x2, y2) + loss_func(x3, y3)
loss_total = loss_func(x1, y1) + loss_func(x2, y2) + loss_func(x3, y3) + loss_func(x4, y4)

So when you call .backward() on loss_total each time, you actually call .backward() on loss_func(x1, y1) four times (and on loss_func(x2, y2) three times, etc.)!
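
To see this repeated counting concretely, here is a toy sketch assuming PyTorch 0.4 or later; the parameter w, the inputs x and the losses loss_k = k * w * x (so that each loss_k has gradient k with respect to w) are hypothetical. retain_graph=True is passed only so that the repeated backward() calls do not raise the error discussed next.

```python
import torch

# Hypothetical toy setup: one scalar parameter w and four "per-image" losses
# loss_k = k * w * x_k with x_k = 1.0, so d(loss_k)/dw = k.
w = torch.tensor(1.0, requires_grad=True)
xs = [torch.tensor(1.0) for _ in range(4)]

loss_total = torch.zeros(())
for k, x in enumerate(xs, start=1):
    loss_total = loss_total + k * w * x
    # Backward on the running sum, as in the loop above.
    loss_total.backward(retain_graph=True)

# Each backward() re-traverses every loss added so far, so loss_1 is counted
# 4 times, loss_2 three times, ...: 4*1 + 3*2 + 2*3 + 1*4 = 20.
print(w.grad.item())  # 20.0
```

Backpropagating each loss exactly once would give 1 + 2 + 3 + 4 = 10 instead.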

Combine that with what is discussed in the other post, namely that, to optimize memory usage, PyTorch frees the graph attached to a Variable when .backward() is called (thereby freeing the intermediate buffers that connect x1 to y1, x2 to y2, etc.), and you can see what the error message means: you try to do backward passes over a loss several times, but the underlying graph was freed after the first pass (unless you specify retain_graph=True, of course).
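
A minimal sketch reproducing the error outside of a training loop, assuming PyTorch 0.4 or later (plain tensors instead of Variable); w and x are hypothetical toy values, and the exact wording of the error message varies between PyTorch versions.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
x = torch.tensor(2.0)

loss = (w * x) ** 2   # the graph saves intermediate values for the backward pass
loss.backward()       # first pass: the saved buffers are freed afterwards

raised = False
try:
    loss.backward()   # second pass over the same, already freed graph
except RuntimeError as e:
    raised = True
    print("RuntimeError:", str(e).splitlines()[0])

# retain_graph=True keeps the buffers and avoids the error, but the gradients
# of both passes are summed into .grad: d/dw (2w)^2 = 8 at w = 1, twice -> 16.
w2 = torch.tensor(1.0, requires_grad=True)
loss2 = (w2 * x) ** 2
loss2.backward(retain_graph=True)
loss2.backward()
print(w2.grad.item())  # 16.0
```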

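As for the specific variations you have tried:

First way: here, you will accumulate (i.e. sum up; again, see the other post) gradients forever, with them (potentially) adding up to inf. Note that total_loss.data.zero_() only resets the stored value, not the graph, so with retain_graph=True every later backward() again backpropagates through all of the earlier losses.

Second way: here, you convert loss to a plain number with loss.data[0], removing the Variable wrapper and thereby discarding the gradient information (since only Variables hold gradients). total_loss then no longer depends on the network parameters, so backward() produces no gradients for them and the weights never change.

Third way: here, you only do one backward pass through each (xk, yk) tuple, since you backprop immediately, avoiding the above problem altogether.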
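
Both failure modes can be checked on a toy example. This sketch assumes PyTorch 0.4 or later, where loss.item() is the replacement for loss.data[0]; w and x are hypothetical.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
x = torch.tensor(2.0)

# First way: zeroing .data resets the value, but the graph is still attached.
total = torch.zeros((), requires_grad=True) + (w * x) ** 2
total.data.zero_()
print(total.item(), total.grad_fn is not None)   # 0.0 True

# Second way: a Python float carries no link back to w,
# so backward() never reaches the parameter.
total2 = torch.zeros((), requires_grad=True) + ((w * x) ** 2).item()
total2.backward()
print(w.grad)   # None
```

Resetting the accumulator by rebinding it to a fresh object (rather than zeroing .data) is what actually drops the old graph.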

SOLUTION: I have not tested it, but from what I gather, the solution should be pretty straightforward: create a new total_loss object at the beginning of each batch, then sum all of the losses into that object, and then do one final backprop step at the end.
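A minimal sketch of that solution, assuming PyTorch 0.4 or later (plain tensors replace Variable, .item() replaces .data[0]) and running on the CPU. net, optimizer, the random dataset and the reduced real_batchsize = 8 are hypothetical stand-ins for the question's objects. It also counts with batch_idx + 1, because batch_idx % real_batchsize == 0 in the question already triggers an update after the very first image.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins for the question's net / train_loader.
net = nn.Sequential(nn.Flatten(), nn.Linear(16, 3), nn.LogSoftmax(dim=1))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
data = TensorDataset(torch.randn(40, 4, 4), torch.randint(0, 3, (40,)))
train_loader = DataLoader(data, batch_size=1, shuffle=True)

real_batchsize = 8  # the question uses 200

history = []
for epoch in range(1, 5):
    net.train()
    optimizer.zero_grad()
    total_loss = 0.0  # fresh accumulator for each "real" batch
    for batch_idx, (images, target) in enumerate(train_loader):
        output = net(images)
        total_loss = total_loss + F.nll_loss(output, target)

        if (batch_idx + 1) % real_batchsize == 0:
            ave_loss = total_loss / real_batchsize
            ave_loss.backward()    # one backward pass for the whole group
            optimizer.step()
            optimizer.zero_grad()
            history.append(ave_loss.item())
            total_loss = 0.0       # rebind: drops the old graph
    print("Epoch:{}, last loss:{:.4f}".format(epoch, history[-1]))
```

Each loss now goes through exactly one backward pass. The graphs of all real_batchsize images are kept in memory until that backward call, so memory use grows with real_batchsize; if that becomes a problem, the per-image accumulation of the third way computes the same gradient with lower memory. Images left over when the dataset size is not a multiple of real_batchsize are simply dropped in this sketch.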
