kli_nlpr

Reputation: 914

How to write a caffe python data layer with preload?

How can I write an asynchronous data layer that preloads batches while other processing is performed? Are there any code examples? Thanks

Upvotes: 4

Views: 514

Answers (1)

Shai

Reputation: 114796

There are several ways you can achieve what you want. I'll try and sketch one option here.

The overall view of the system: you have n Loaders asynchronously loading data and feeding a queue. The layer then reads batch_size items from the queue and feeds the net in its forward() function.

import multiprocessing

import caffe

class Loader(multiprocessing.Process):
  def __init__(self, outq, *args, **kwargs):
    super(Loader, self).__init__()
    self.daemon = True
    self.outq = outq
    self.start()  # start working immediately

  def run(self):
    while True:  # read and never stop at all!
      try:
        # do your magic here, e.g. read an image x and its label y from disk
        x, y = self.load_sample()  # implement your own loading logic
        self.outq.put((x[None, ...], y[None, ...]))  # add singleton "batch" dimension
      except Exception:
        pass  # consider logging errors rather than silently dropping samples

class MultiProcessInputLayer(caffe.Layer):
  def setup(self, bottom, top):
    # verify there are no bottoms, the right number of tops, etc.
    self.batch_size = 32  # or parse it from self.param_str
    n = 4                 # number of loader processes; tune to your machine
    self.dataQ = multiprocessing.Queue()
    for _ in xrange(n):
      Loader(self.dataQ)  # start n Loaders
    # some other stuff here...

  def reshape(self, bottom, top):
    # reshape the tops to the right sizes, e.g.
    # top[0].reshape(self.batch_size, channels, height, width)
    pass

  def forward(self, bottom, top):
    for i in xrange(self.batch_size):
      item = self.dataQ.get()
      top[0].data[i, ...] = item[0]
      top[1].data[i, ...] = item[1]

  def backward(self, top, propagate_down, bottom):
    pass  # no backward for a data layer
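
If it helps, here is a minimal sketch of how such a layer is typically hooked into the net's prototxt (the module/layer names are placeholders for your own file and class, and Caffe must be built with WITH_PYTHON_LAYER := 1):

layer {
  name: "data"
  type: "Python"
  top: "data"
  top: "label"
  python_param {
    module: "my_data_layer"  # hypothetical: the .py file on your PYTHONPATH
    layer: "MultiProcessInputLayer"
  }
}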

Some tips and tricks I learned the hard way:
1. Use the multiprocessing package rather than threading, because of Python's GIL.
2. Sometimes (e.g., if batch_size is very large) it can take a long time for forward() to read item by item from the queue to assemble each batch. In that case, you can add another multiprocessing.Process that asynchronously reads batch_size items from self.dataQ and writes whole batches to self.batchQ. forward() then only waits for a single item from self.batchQ on each call (see the sketch after this list).
3. Take care not to copy the data around too much. Working with large images/labels can make all this copying a bottleneck.
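
For tip 2, a minimal sketch of such a batch-assembling process could look like this (BatchAssembler and its arguments are names I made up for illustration; it assumes each queued item already carries the singleton "batch" dimension added by Loader):

import multiprocessing

import numpy as np

class BatchAssembler(multiprocessing.Process):
  def __init__(self, inq, outq, batch_size):
    super(BatchAssembler, self).__init__()
    self.daemon = True
    self.inq = inq
    self.outq = outq
    self.batch_size = batch_size
    self.start()

  def run(self):
    while True:
      xs, ys = [], []
      for _ in xrange(self.batch_size):
        x, y = self.inq.get()  # each item already has a singleton batch dimension
        xs.append(x)
        ys.append(y)
      # concatenate along the batch axis and ship a whole batch at once
      self.outq.put((np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)))

With this in place, setup() also creates self.batchQ = multiprocessing.Queue() and one BatchAssembler(self.dataQ, self.batchQ, self.batch_size), and forward() shrinks to a single self.batchQ.get() per call.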

Upvotes: 4
