Deshwal

Is it a good idea to multiply loss.item() by batch_size to get the loss of a batch when batch_size is not a factor of train_size?

Suppose we have a problem with 100 images and a batch size of 15. All of our batches contain 15 images except the last one, which contains 10.
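The batch sizes follow from integer division. This is a minimal sketch, assuming the loader keeps the final partial batch (DataLoader's default, drop_last=False); `batch_sizes` is a hypothetical helper, not a PyTorch function:

```python
def batch_sizes(n_samples, batch_size):
    """Sizes of consecutive batches when the last, partial batch is kept."""
    full_batches, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full_batches
    if remainder:
        # the leftover samples form one smaller final batch
        sizes.append(remainder)
    return sizes


print(batch_sizes(100, 15))  # [15, 15, 15, 15, 15, 15, 10]
```

So 100 images give seven batches: six of 15 and one of 10.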

Suppose we train the network as follows:

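import torch
import torch.nn.functional as F
import torch.optim as optim

# Network and train_set are defined elsewhere (train_set holds 100 images)
network = Network()
optimizer = optim.Adam(network.parameters(), lr=0.001)

for epoch in range(5):

    total_loss = 0

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=15)

    for batch in train_loader:
        images, labels = batch

        pred = network(images)
        loss = F.cross_entropy(pred, labels)  # mean loss over the batch

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # scale the batch-mean loss back up to a batch total
        total_loss += loss.item() * 15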

Isn't the last batch always going to give us an inflated loss, because we multiply by 15 where we should multiply by 10? Shouldn't it be total_loss += loss.item()*len(images) instead of 15 or batch_size?
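A small worked example makes the size of the error concrete. It assumes, purely for illustration, that every image has a loss of exactly 1.0, so each batch's mean loss is 1.0 and the true per-image average is 1.0:

```python
sizes = [15] * 6 + [10]            # 100 images, batch_size=15
batch_means = [1.0] * len(sizes)   # hypothetical: every per-image loss is 1.0

# Hard-coded multiplier: treats the last batch as if it had 15 images
hardcoded_total = sum(mean * 15 for mean in batch_means)
# Actual batch size: what len(images) would return for each batch
correct_total = sum(mean * n for mean, n in zip(batch_means, sizes))

print(hardcoded_total / 100)  # 1.05
print(correct_total / 100)    # 1.0
```

In general the hard-coded version overshoots the epoch total by 5 times the mean loss of the last batch. Because cross-entropy is never negative, this can only inflate the reported loss.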

Can we use this instead?

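for epoch in range(5):
    total_loss = 0
    for images, labels in train_loader:
        pred = network(images)
        # sum (rather than average) the per-image losses in the batch
        loss = F.cross_entropy(pred, labels, reduction='sum')
        # ... zero_grad(), backward() and step() as in the loop above ...
        total_loss += loss.item()

    avg_loss_per_epoch = total_loss / len(train_set)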
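To check that the reduction='sum' bookkeeping gives the same epoch average as multiplying the mean loss by the actual batch size, here is a sketch in plain Python. `cross_entropy` is a hypothetical stand-in that follows the same formula as F.cross_entropy on raw logits (negative log-softmax of the target class, then averaged or summed); the random logits and labels are made-up data:

```python
import math
import random


def cross_entropy(logits, labels, reduction="mean"):
    """Plain-Python stand-in for F.cross_entropy on raw (unnormalized) logits."""
    losses = []
    for row, target in zip(logits, labels):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[target])
    total = sum(losses)
    return total if reduction == "sum" else total / len(losses)


rng = random.Random(0)
n_samples, n_classes, batch_size = 100, 3, 15
logits = [[rng.gauss(0, 1) for _ in range(n_classes)] for _ in range(n_samples)]
labels = [rng.randrange(n_classes) for _ in range(n_samples)]

sum_total = 0.0   # accumulates reduction="sum" losses
mean_total = 0.0  # accumulates reduction="mean" losses times the real batch size
for start in range(0, n_samples, batch_size):
    xb = logits[start:start + batch_size]  # the last slice holds only 10 samples
    yb = labels[start:start + batch_size]
    sum_total += cross_entropy(xb, yb, reduction="sum")
    mean_total += cross_entropy(xb, yb) * len(xb)

avg_sum = sum_total / n_samples
avg_mean = mean_total / n_samples
print(math.isclose(avg_sum, avg_mean))  # True
```

Both agree (up to floating-point rounding) with the mean loss over the whole dataset computed in one go.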

Can someone please explain whether multiplying by batch_size is a good idea, and where I am wrong?

Answers (1)

Harshit Kumar

Yes, you're right. Usually, for the running loss, the term

total_loss+= loss.item()*15

is written instead as (as done in the PyTorch transfer learning tutorial)

total_loss+= loss.item()*images.size(0)

where images.size(0) gives the current batch size. For your last batch it is therefore 10 instead of the hard-coded 15. loss.item()*len(images) is also correct!
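The same bookkeeping can be wrapped in a small accumulator that also counts the images it has seen, so the epoch average relies neither on a hard-coded batch size nor on len(train_set). `RunningLoss` is a hypothetical helper shown in plain Python; inside the training loop you would call tracker.update(loss.item(), images.size(0)):

```python
class RunningLoss:
    """Hypothetical helper: accumulates batch-mean losses weighted by batch size."""

    def __init__(self):
        self.total = 0.0  # sum of per-sample losses seen so far
        self.count = 0    # number of samples seen so far

    def update(self, batch_mean_loss, batch_size):
        # Undo the per-batch averaging using the *actual* batch size
        self.total += batch_mean_loss * batch_size
        self.count += batch_size

    @property
    def average(self):
        return self.total / self.count if self.count else 0.0


tracker = RunningLoss()
for mean_loss, n in [(0.5, 15)] * 6 + [(2.0, 10)]:  # made-up batch means
    tracker.update(mean_loss, n)

print(tracker.count)    # 100
print(tracker.average)  # 0.65
```

With the same made-up numbers, a hard-coded multiplier of 15 would report 0.75 instead of 0.65.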

In your second example, since you're using reduction='sum', the loss won't be divided by the batch size as it is by default (the default is reduction='mean', i.e. the losses are averaged across the observations in each minibatch). Summing loss.item() over all batches and dividing by len(train_set) therefore gives the correct average loss per image.
