Reputation: 1
How can I effectively increase the mini-batch size when using the StandardUpdater class in Chainer?
In PyTorch, I can effectively increase the mini-batch size by accumulating gradients over several backward passes before applying a parameter update.
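For illustration, this is roughly what I mean (a minimal sketch; model, loader, optimizer, and loss_fn are placeholders):

    accumulation_steps = 4  # one parameter update every 4 mini-batches

    optimizer.zero_grad()
    for i, (x, t) in enumerate(loader):
        loss = loss_fn(model(x), t) / accumulation_steps  # scale so gradients match one big batch
        loss.backward()                                   # gradients accumulate in the .grad buffers
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()       # apply the accumulated gradients
            optimizer.zero_grad()  # start the next accumulation cycle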
Question 1. Is it possible to effectively increase the mini-batch size in Chainer as well?
Question 2. I'm currently using the StandardUpdater class. Does it offer a hyperparameter for this, or should I write my own class that inherits from StandardUpdater and change its update implementation?
I'm sorry if this has already been asked. Any advice would be appreciated.
Upvotes: 0
Views: 64
Reputation: 1270
(The question is quite old, but I stumbled upon it and wanted to share my solution.)
You would basically do it the same way as in PyTorch. Unfortunately, StandardUpdater has neither a hyperparameter that supports this nor a built-in implementation of such "mini-batch updates". But here is how I did it (basically as you suggested in your question: inherit from StandardUpdater and re-implement the update_core method):
from chainer.training import StandardUpdater
from chainer.dataset import convert


class MiniBatchUpdater(StandardUpdater):
    """Accumulates gradients over several mini-batches.

    The iterator outputs batches in mini-batch sizes. This updater
    accumulates the gradients of these mini-batches until update_size
    examples have been processed; only then is a parameter update
    performed.
    """

    def __init__(self, update_size=32, *args, **kwargs):
        super(MiniBatchUpdater, self).__init__(*args, **kwargs)
        self.update_size = update_size
        self.iteration_counter = 0

    def update_core(self):
        optimizer = self.get_optimizer('main')
        loss_func = self.loss_func or optimizer.target
        it = self.get_iterator('main')
        batch = it.next()
        data = convert._call_converter(self.converter, batch, self.device)

        # Clear the gradients only at the start of an accumulation cycle,
        # so that the backward passes of later mini-batches add up.
        use_cleargrads = getattr(optimizer, '_use_cleargrads', True)
        if use_cleargrads and self.iteration_counter == 0:
            optimizer.target.cleargrads()

        self.iteration_counter += it.batch_size
        loss = loss_func(*data)
        loss.backward()

        # Once enough examples have been seen, apply the accumulated
        # gradients and reset the counter for the next cycle.
        if self.iteration_counter >= self.update_size:
            self.iteration_counter = 0
            optimizer.update()
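For completeness, here is roughly how you would plug it into a training loop. This is a minimal sketch: model, optimizer, and the train dataset are assumed to be set up already, and the concrete numbers are just for illustration.

    from chainer import training
    from chainer.iterators import SerialIterator

    # 'train' is a Chainer dataset; 'optimizer' has been set up on the model.
    train_iter = SerialIterator(train, batch_size=8)  # small physical batches
    updater = MiniBatchUpdater(update_size=32,        # effective batch size
                               iterator=train_iter,
                               optimizer=optimizer)
    trainer = training.Trainer(updater, (20, 'epoch'), out='result')
    trainer.run()

One thing to keep in mind: plain accumulation sums the gradients, so the effective gradient magnitude grows with the number of accumulated mini-batches; depending on the optimizer you may want to scale the loss (or the learning rate) by batch_size / update_size.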
The implementation is quite old (written for Chainer 4 or 5, I think), but it works for me with Chainer 7.8 as well. One could update some lines to match the newer implementation of the update_core method, but as I said, it works for me. Hopefully it helps ;)
Upvotes: 0