Reputation: 13103
I'm trying to write a parallel data loader for deep learning in Rust. The task is to write an iterator that under the hood does the following:

1. Loads a file and preprocesses it into a single example (typically a numeric array)
2. Takes B of those examples and "collates" them - this generally means just concatenating the arrays - moderately compute heavy
3. Yields the collated batch

Step 1 can be both IO and compute bound, depending on network latency, size of files and complexity of preprocessing. It has to be run in parallel by many workers. Step 2 should be off the main thread but likely doesn't need a pool of workers. Step 3 happens on the main thread (exposed to Python).
The reason I'm writing it in Rust is that Python offers two options: a pure Python implementation shipped with PyTorch, based on multiprocessing, which is somewhat slow but very flexible (arbitrary user-defined data preprocessing and batching), and a C++ implementation shipped with TensorFlow, which is assembled by the user from a set of predefined primitives. The latter is substantially faster but too restrictive for the kinds of data processing I wish to do. I expect that Rust will give me the speed of TensorFlow with the flexibility of arbitrary code, as in PyTorch.
My question is purely about the way to implement parallelism. The ideal setup is to have N workers for step 1) -> channel -> worker for step 2) -> channel -> step 3. Because the iterator object may be dropped at any time, there is a strict requirement to be able to terminate the whole scheme when Drop is called. On the other hand, there is the flexibility of loading the files in an arbitrary order: for example, if the batch size B == 16 and max_n_threads == 32, it is perfectly fine to start 32 workers and yield the first batch containing the 16 examples which happen to return first. This can be exploited for speed.
My naive implementation creates the DataLoader in 3 steps:

1. Create an n_working: Arc<AtomicUsize> to control the number of active worker threads and a should_shutdown: Arc<AtomicBool> to signal shutdown (when Drop is called)
2. Spawn a manager thread which spins while the pool is full and, whenever n_working < max_n_threads, spawns another worker; each worker terminates if should_shutdown is set, otherwise fetches a single example, sends it down the worker->batcher channel and decrements n_working
3. Spawn a batching thread which receives from the worker->batcher channel, accumulates B objects, concatenates them into a batch and sends it down the batcher->yielder channel

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::thread;
use crossbeam::channel::{bounded, Receiver};
use pyo3::prelude::*;

#[pyclass]
struct DataLoader {
collate_worker: Option<thread::JoinHandle<()>>,
example_worker: Option<thread::JoinHandle<()>>,
should_shut_down: Arc<AtomicBool>,
receiver: Receiver<Batch>,
length: usize,
}
impl DataLoader {
fn new(
dataset: Dataset,
batch_size: usize,
capacity: usize,
) -> Self {
let n_batches = dataset.len() / batch_size;
let max_n_threads = capacity * batch_size;
let (example_sender, collate_receiver) = bounded((batch_size - 1) * capacity);
let should_shut_down = Arc::new(AtomicBool::new(false));
let shutdown_flag = should_shut_down.clone();
        // manager thread: keeps up to max_n_threads example-fetching tasks in flight
        let example_worker = thread::spawn(move || {
rayon::scope_fifo(|s| {
let dataset = &dataset;
let n_working = Arc::new(AtomicUsize::new(0));
let mut current_index = 0;
while current_index < n_batches * batch_size {
                    // busy-wait until a worker slot frees up
                    if n_working.load(Ordering::Relaxed) == max_n_threads {
                        continue;
                    }
if shutdown_flag.load(Ordering::Relaxed) {
break;
}
                    let index = current_index;
let sender = example_sender.clone();
let counter = n_working.clone();
let shutdown_flag = shutdown_flag.clone();
s.spawn_fifo(move |_s| {
let example = dataset.get_example(index);
if !shutdown_flag.load(Ordering::Relaxed) {
_ = sender.send(example);
} // if we should shut down, skip sending
counter.fetch_sub(1, Ordering::Relaxed);
});
current_index += 1;
n_working.fetch_add(1, Ordering::Relaxed);
};
});
});
let (batch_sender, final_receiver) = bounded(capacity);
let shutdown_flag = should_shut_down.clone();
        // collate thread: groups batch_size examples into a single Batch
        let collate_worker = thread::spawn(move || {
'outer: loop {
let mut batch = vec![];
for _ in 0..batch_size {
if let Ok(example) = collate_receiver.recv() {
batch.push(example);
} else {
break 'outer;
}
};
let collated = collate(batch);
if shutdown_flag.load(Ordering::Relaxed) {
break; // skip sending
}
_ = batch_sender.send(collated);
};
});
Self {
collate_worker: Some(collate_worker),
example_worker: Some(example_worker),
            should_shut_down,
receiver: final_receiver,
length: n_batches,
}
}
}
#[pymethods]
impl DataLoader {
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> { slf }
fn __next__(&mut self) -> Option<Batch> {
self.receiver.recv().ok()
}
fn __len__(&self) -> usize {
self.length
}
}
impl Drop for DataLoader {
fn drop(&mut self) {
self.should_shut_down.store(true, Ordering::Relaxed);
if self.collate_worker.take().unwrap().join().is_err() {
println!("Panic in collate worker");
};
if self.example_worker.take().unwrap().join().is_err() {
println!("Panic in example_worker");
};
println!("dropped the dataloader");
}
}
This implementation works and roughly matches the performance of PyTorch, but provides no significant speedup. I don't know where to look for improvements, but I imagine it would help to have the thing load-balance automatically in a work-stealing way and to flexibly spawn workers depending on the proportion of IO and compute time. I am also expecting performance issues due to the spinning pool manager and likely corner cases in my handling of Drop.
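For concreteness, here is a tiny standalone sketch (separate from my loader, all names made up) of the kind of non-spinning dispatch I have in mind: a bounded channel of unit "permits" acts as a counting semaphore, so the manager blocks instead of polling n_working:

use crossbeam::channel::bounded;
use std::thread;

fn main() {
    let max_n_threads = 4;
    let n_jobs = 16;
    // a bounded channel of unit "permits" works as a counting semaphore
    let (permit_tx, permit_rx) = bounded::<()>(max_n_threads);
    for _ in 0..max_n_threads {
        permit_tx.send(()).unwrap(); // fill the pool with permits
    }
    let mut handles = Vec::new();
    for job in 0..n_jobs {
        permit_rx.recv().unwrap(); // take a permit; blocks while all slots are busy
        let tx = permit_tx.clone();
        handles.push(thread::spawn(move || {
            // stand-in for "fetch one example and send it to the collator"
            println!("processing example {job}");
            tx.send(()).unwrap(); // hand the permit back when done
        }));
    }
    for handle in handles {
        handle.join().unwrap();
    }
}

In the real loader the worker task would hand the permit back after sending its example down the worker->batcher channel.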
My question is how to best approach the problem. I am generally unsure if this should be tackled with parallel crates like rayon, async crates like tokio, or a mix of both. I also have the hunch that my implementation could be much simpler with the correct use of their combinators/higher-order APIs. I tried with rayon but couldn't get a solution which doesn't wastefully enforce the original sequential return order while still respecting the Drop requirement.
Upvotes: 2
Views: 389
Reputation: 15002
Okay I think I've figured out a solution for you that uses rayon parallel iterators.
The trick is to use Result in the rayon iterators, and return Err if the cancellation flag is set.
I first created a utility type for a cancellable thread in which you can execute rayon iterators. You use it by passing in the thread closure, which takes the atomic cancellation token as a parameter. Inside the closure you check whether the cancellation token is true, and if so, exit early.
use std::sync::Arc;
use std::sync::atomic::{Ordering, AtomicBool};
use std::thread::JoinHandle;
fn collate(batch: &[Computed]) -> Batch {
batch.iter().map(|&x| i128::from(x)).sum()
}
#[derive(Debug)]
struct Cancelled;
struct CancellableThread<Output: Send + 'static> {
cancel_token: Arc<AtomicBool>,
thread: Option<JoinHandle<Result<Output, Cancelled>>>,
}
impl<Output: Send + 'static> CancellableThread<Output> {
fn new<F: FnOnce(Arc<AtomicBool>) -> Result<Output, Cancelled> + Send + 'static>(init: F) -> Self {
let cancel_token = Arc::new(AtomicBool::new(false));
let thread_cancel_token = Arc::clone(&cancel_token);
CancellableThread {
thread: Some(std::thread::spawn(move || init(thread_cancel_token))),
cancel_token,
}
}
fn output(mut self) -> Output {
self.thread.take().unwrap().join().unwrap().unwrap()
}
}
impl<Output: Send + 'static> Drop for CancellableThread<Output> {
fn drop(&mut self) {
self.cancel_token.store(true, Ordering::Relaxed);
if let Some(thread) = self.thread.take() {
            // re-raise a panic from the worker; otherwise ignore whether it finished or was cancelled
            let _ = thread.join().unwrap();
}
}
}
I found it useful to create a closure that returns a Result<(), Cancelled> so I could use the try operator (?) to exit early.
CancellableThread::new(move |cancel_token| {
let cancelled = || if cancel_token.load(Ordering::Relaxed) {
Err(Cancelled)
} else {
Ok(())
};
loop {
// was the thread dropped?
// if so, stop what we're doing
cancelled?;
// do stuff and
// eventually return a result
}
});
I then used that CancellableThread abstraction in the DataLoader. There is no need to write a special Drop impl for it, because by default drop is called on each field anyway, which handles the cancellation.
type Data = Vec<u8>;
type Dataset = Vec<Data>;
type Computed = u64;
type Batch = i128;
use rayon::prelude::*;
use crossbeam::channel::{unbounded, Receiver};
struct DataLoader {
example_worker: CancellableThread<()>,
collate_worker: CancellableThread<()>,
receiver: Receiver<Batch>,
length: usize,
}
I used unbounded channels, as it was one less thing to bother about. It shouldn't be hard to switch to bounded ones instead.
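For example, a sketch of the bounded variant (switching the import from unbounded to bounded, and taking a capacity argument measured in batches, similar to the question's code; that parameter is not part of the code below):

let (example_sender, collate_receiver) = bounded(batch_size * capacity);
let (batch_sender, final_receiver) = bounded(capacity);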
impl DataLoader {
fn new(dataset: Dataset, batch_size: usize) -> Self {
let (example_sender, collate_receiver) = unbounded();
let (batch_sender, final_receiver) = unbounded();
I'm not sure if you can always guarantee that the number of items in your dataset will be a multiple of the batch_size, so I decided to handle that explicitly.
let length = if dataset.len() % batch_size == 0 {
dataset.len() / batch_size
} else {
dataset.len() / batch_size + 1
};
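As an aside, on Rust 1.73 and later the same ceiling division is available as a standard library method, so the branch above can be written in one line:

let length = dataset.len().div_ceil(batch_size);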
I created the collating worker first, though that may not be necessary. As you can see, I had to duplicate a little bit to handle partial batches.
let collate_worker = CancellableThread::new(move |cancel_token| {
let cancelled = || if cancel_token.load(Ordering::Relaxed) {
Err(Cancelled)
} else {
Ok(())
};
'outer: loop {
let mut batch = Vec::with_capacity(batch_size);
for _ in 0..batch_size {
cancelled()?;
if let Ok(data) = collate_receiver.recv() {
batch.push(data);
} else {
if !batch.is_empty() {
// handle the last batch, if there
// weren't enough items to fill it
let collated = collate(&batch);
cancelled()?;
batch_sender.send(collated).unwrap();
}
break 'outer;
}
}
let collated = collate(&batch);
cancelled()?;
batch_sender.send(collated).unwrap();
}
Ok(())
});
The example worker is where things are really made much simpler, because we can just use rayon parallel iterators. As you can see, we check for cancellation before each heavy computation.
let example_worker = CancellableThread::new(move |cancel_token| {
let cancelled = || if cancel_token.load(Ordering::Relaxed) {
Err(Cancelled)
} else {
Ok(())
};
let heavy_compute = |data: Data| -> Result<Computed, Cancelled> {
cancelled()?;
Ok(data.iter().map(|&x| u64::from(x)).product())
};
dataset
.into_par_iter()
.map(heavy_compute)
.try_for_each(|computed| {
example_sender.send(computed?).unwrap();
Ok(())
})
});
Then we just construct the DataLoader. You can see the Python impl is identical:
DataLoader {
example_worker,
collate_worker,
receiver: final_receiver,
length,
}
}
}
// #[pymethods]
impl DataLoader {
fn __iter__(this: Self /* PyRef<Self> */) -> Self /* PyRef<Self> */ { this }
fn __next__(&mut self) -> Option<Batch> {
self.receiver.recv().ok()
}
fn __len__(&self) -> usize {
self.length
}
}
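For completeness, here is a small standalone driver to sanity-check the flow outside Python; the dataset contents and batch size are made up for illustration:

fn main() {
    // 100 fake "examples" of 8 bytes each, batched 16 at a time
    let dataset: Dataset = (0..100u8).map(|i| vec![i; 8]).collect();
    let mut loader = DataLoader::new(dataset, 16);
    assert_eq!(loader.__len__(), 7); // 6 full batches plus 1 partial batch
    while let Some(batch) = loader.__next__() {
        println!("got batch: {batch}");
    }
    // dropping `loader` joins (and, if still running, cancels) both workers
}

If everything is wired correctly, dropping the loader halfway through the loop should also return promptly thanks to the cancellation tokens.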
Upvotes: 1