Reputation: 28449
I am doing a task where the batch size is 1, i.e., each batch contains only 1 image. So I have to do manual batching: when the number of accumulated losses reaches a certain number, average the loss and then do backpropagation. My original code is:
real_batchsize = 200

for epoch in range(1, 5):
    net.train()

    total_loss = Variable(torch.zeros(1).cuda(), requires_grad=True)
    iter_count = 0
    for batch_idx, (input, target) in enumerate(train_loader):
        input, target = Variable(input.cuda()), Variable(target.cuda())
        output = net(input)
        loss = F.nll_loss(output, target)

        total_loss = total_loss + loss

        if batch_idx % real_batchsize == 0:
            iter_count += 1
            ave_loss = total_loss / real_batchsize
            ave_loss.backward()
            optimizer.step()
            if iter_count % 10 == 0:
                print("Epoch:{}, iteration:{}, loss:{}".format(epoch,
                                                               iter_count,
                                                               ave_loss.data[0]))
            total_loss.data.zero_()
            optimizer.zero_grad()
This code will give the error message
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
I read some posts about this error message, but cannot understand it fully. I have tried the following ways.

First way: Changing ave_loss.backward() to ave_loss.backward(retain_graph=True) prevents the error message, but the loss doesn't improve and soon becomes nan.

Second way: I also tried changing total_loss = total_loss + loss.data[0]. This also prevents the error message, but the loss stays exactly the same, so there must be something wrong.

Third way: Following the instructions in this post, for each image's loss I divide the loss by real_batchsize and backprop it. When the number of input images reaches real_batchsize, I do one parameter update using optimizer.step(). The loss slowly decreases as training goes on, but the training speed is really slow, because we backprop for each image.
What does the error message mean in my case? Also, why don't the first and second ways work? How should I write the code so that we backprop the gradient every real_batchsize images and update the parameters once, so that training is faster? I know my code is nearly correct, but I just do not know how to change it.
Upvotes: 2
Views: 3662
Reputation: 4343
The problem you encounter here is related to how PyTorch accumulates gradients over different passes (see here for another post on a similar question). So let's have a look at what happens when you have code of the following form:
loss_total = Variable(torch.zeros(1).cuda(), requires_grad=True)

for l in (loss_func(x1, y1), loss_func(x2, y2), loss_func(x3, y3), loss_func(x4, y4)):
    loss_total = loss_total + l
    loss_total.backward()
Here, we do a backward pass when loss_total has the following values over the different iterations:

loss_total = loss(x1, y1)
loss_total = loss(x1, y1) + loss(x2, y2)
loss_total = loss(x1, y1) + loss(x2, y2) + loss(x3, y3)
loss_total = loss(x1, y1) + loss(x2, y2) + loss(x3, y3) + loss(x4, y4)

So when you call .backward() on loss_total each time, you actually call .backward() on loss(x1, y1) four times (and on loss(x2, y2) three times, etc.).
Combine that with what is discussed in the other post, namely that to optimize memory usage, PyTorch frees the graph attached to a Variable when .backward() is called (thereby destroying the gradients connecting x1 to y1, x2 to y2, etc.), and you can see what the error message means: you try to do a backward pass over a loss several times, but the underlying graph was freed after the first pass (unless you specify retain_graph=True, of course).
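To make that concrete, here is a minimal, self-contained sketch of the behaviour, written against the newer tensor API (where a tensor with requires_grad=True plays the role of a Variable):

import torch

# A tiny graph: the scalar y depends on x through one multiplication and a sum.
x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

y.backward()        # first pass works; the graph's buffers are freed afterwards
# y.backward()      # a second pass would raise "Trying to backward through the graph
                    # a second time ..." because the graph is already gone

y = (x * 2).sum()               # rebuild the graph
y.backward(retain_graph=True)   # keep the buffers alive ...
y.backward()                    # ... so a second pass now succeeds (gradients add up in x.grad)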
As for the specific variations you have tried:
First way: here, you will accumulate (i.e. sum up; again, see the other post) gradients forever, with them (potentially) adding up to inf.
Second way: here, you convert loss to a tensor by doing loss.data, removing the Variable wrapper, and thereby deleting the gradient information (since only Variables hold gradients).
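Roughly (again sketched with the newer tensor API), that effect looks like this:

import torch

x = torch.ones(3, requires_grad=True)
loss = (x * 2).sum()

plain = loss.data            # same value, but cut off from the autograd graph
print(plain.requires_grad)   # False: accumulating `plain` can never propagate gradients back to x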
Third way: here, you only do one pass through each xk, yk tuple, since you immediately do a backprop step, avoiding the above problem altogether.
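That approach would look roughly like the sketch below (untested; it reuses net, optimizer, train_loader, and real_batchsize from the question):

import torch.nn.functional as F

for batch_idx, (input, target) in enumerate(train_loader):
    input, target = input.cuda(), target.cuda()
    loss = F.nll_loss(net(input), target) / real_batchsize   # scale each per-image loss
    loss.backward()                                          # gradients accumulate in the parameters' .grad

    if (batch_idx + 1) % real_batchsize == 0:                # one update every real_batchsize images
        optimizer.step()
        optimizer.zero_grad()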
SOLUTION: I have not tested it, but from what I gather, the solution should be pretty straightforward: create a new total_loss object at the beginning of each batch, then sum all of the losses into that object, and then do one final backprop step at the end.
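A sketch of what that could look like, adapted from the loop in the question (untested; it keeps the question's net, optimizer, and train_loader, and drops the Variable wrappers, which newer PyTorch no longer needs):

import torch.nn.functional as F

real_batchsize = 200

for epoch in range(1, 5):
    net.train()
    optimizer.zero_grad()
    total_loss = 0.0                                            # fresh accumulator for the first pseudo-batch

    for batch_idx, (input, target) in enumerate(train_loader):
        input, target = input.cuda(), target.cuda()
        output = net(input)
        total_loss = total_loss + F.nll_loss(output, target)    # stays attached to the graph

        if (batch_idx + 1) % real_batchsize == 0:
            ave_loss = total_loss / real_batchsize
            ave_loss.backward()                                 # one backward pass per pseudo-batch
            optimizer.step()
            optimizer.zero_grad()
            total_loss = 0.0                                    # start a new graph for the next pseudo-batch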
Upvotes: 3