takeshi0212

Reputation: 1

How can I effectively increase the mini-batch size when using the StandardUpdater class in Chainer?

In PyTorch, I can effectively increase the mini-batch size by accumulating gradients over several small batches before each optimizer update.
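For example, in PyTorch a loop like the following is possible (just a rough sketch; MyNet and train_loader are placeholders for my actual model and data loader):

import torch

model = MyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

accumulation_steps = 4                  # effective batch size = loader batch size * 4
optimizer.zero_grad()
for i, (x, t) in enumerate(train_loader):
    loss = criterion(model(x), t) / accumulation_steps
    loss.backward()                     # gradients accumulate across iterations
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()                # one update per accumulation_steps mini-batches
        optimizer.zero_grad()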

Question 1: Is it possible to effectively increase the mini-batch size in Chainer in the same way?

Question 2: I am currently using the StandardUpdater class. Is there a hyper-parameter that makes this possible, or should I write my own class that inherits from StandardUpdater and override its implementation?

I'm sorry if these questions have already been asked.

I would appreciate any advice.

Upvotes: 0

Views: 64

Answers (1)

DiKorsch

Reputation: 1270

(The question is quite old, but I stumbled upon it and wanted to share my solution.)

You would basically do it the same way as in PyTorch. Unfortunately, the StandardUpdater has neither a hyper-parameter for this nor a built-in implementation of such accumulated "mini-batch updates". Here is how I did it (basically as you suggested in your question: inherit from StandardUpdater and re-implement the update_core method):

from chainer.training import StandardUpdater
from chainer.dataset import convert

class MiniBatchUpdater(StandardUpdater):
    """
        The iterator outputs batches in mini-batch sizes. This updater
        cummulates the gradients of these mini-batches until the
        update_size is reached. Then a parameter update is performed
    """
    def __init__(self, update_size=32, *args, **kwargs):
        super(MiniBatchUpdater, self).__init__(*args, **kwargs)
        self.update_size = update_size
        self.iteration_counter = 0

    def update_core(self):
        optimizer = self.get_optimizer('main')
        loss_func = self.loss_func or optimizer.target
        it = self.get_iterator('main')

        batch = it.next()
        data = convert._call_converter(self.converter, batch, self.device)

        # clear the gradients only at the start of a new accumulation cycle
        use_cleargrads = getattr(optimizer, '_use_cleargrads', True)
        if use_cleargrads and self.iteration_counter == 0:
            optimizer.target.cleargrads()

        # accumulate the gradients of this mini-batch (backward adds to existing grads)
        self.iteration_counter += it.batch_size
        loss = loss_func(*data)
        loss.backward()

        # perform a parameter update once enough samples have been accumulated
        if self.iteration_counter >= self.update_size:
            self.iteration_counter = 0
            optimizer.update()

The implementation is quite old (I think it was written for Chainer 4 or 5), but it works for me with Chainer 7.8 as well. One could update some lines to match the newer implementation of the update_core method, but as I said, it works for me. Hopefully it helps ;)
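For reference, a minimal usage sketch could look like the following (MyModel, train_dataset, and the hyper-parameters are just placeholders, assuming the usual Chainer Trainer setup):

import chainer
from chainer import iterators, optimizers, training

model = chainer.links.Classifier(MyModel())      # MyModel is a placeholder chainer.Chain
optimizer = optimizers.Adam()
optimizer.setup(model)

# the iterator yields small mini-batches of 8 samples; gradients are
# accumulated until 32 samples have been seen, then one update is made
it = iterators.SerialIterator(train_dataset, batch_size=8)
updater = MiniBatchUpdater(update_size=32, iterator=it,
                           optimizer=optimizer, device=-1)

trainer = training.Trainer(updater, (20, 'epoch'), out='result')
trainer.run()

Note that update_core accumulates the gradients of the summed losses; if you want the result to match the average gradient of one large batch, you may want to divide the loss by the number of accumulated mini-batches before calling backward.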

Upvotes: 0
