Reputation: 528
I am using rayon's .par_bridge().for_each(|r| { do_work(r) }) to run tasks in parallel over some iterator (specifically: Records from a BED file, but I don't think that matters). There could be up to ~700,000 tasks.
I want to print (to stdout or to a file) the result of every call to do_work(), but do this printing only in the order of the original iterator. I could sort all output after all the parallel jobs have completed, but storing every result until the end would require much more memory. Alternatively, I could add .enumerate() to get an index for each item, print each item as soon as its turn comes, and buffer the ones that finish early, but I am not sure how best to implement such a system, or whether it is the best solution at all. What would you suggest?
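Roughly, the reordering logic I have in mind on the consuming side would look like this (a single-threaded sketch only; print_in_order is just an illustrative name, and it would be fed (index, result) pairs from the parallel workers, e.g. through a channel):
use std::cmp::Reverse;
use std::collections::BinaryHeap;

// Buffer completed items in a min-heap keyed by their original index and
// print each one only once all earlier indices have been printed.
fn print_in_order<T: Ord + std::fmt::Display>(results: impl Iterator<Item = (usize, T)>) {
    let mut next = 0usize;
    let mut pending = BinaryHeap::new();
    for pair in results {
        pending.push(Reverse(pair));
        // Flush every buffered item whose turn has come.
        while pending.peek().map_or(false, |Reverse((i, _))| *i == next) {
            let Reverse((_, value)) = pending.pop().unwrap();
            println!("{value}");
            next += 1;
        }
    }
}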
Upvotes: 1
Views: 1450
Reputation: 101
Here is a version adapted from @kevin-reid's.
As mentioned in my comment on his post, with his code, if one thread is slow, the others keep working anyway, so more and more data accumulates in the buffer. In the worst case, everything but the first item ends up stored there.
My adaptation blocks a thread from sending its result until it is its turn. This means there will be at most one result per thread in memory (plus whatever is waiting in the sync_channel).
The main drawback is the need for the Mutex and notify_all, which can have a significant cost when the work itself is quick.
It's also possible to combine both answers to let the threads run ahead, but within reason; I leave that as an exercise 😜 (a sketch of what that could look like follows the code below).
use std::{
    sync::{mpsc, Arc, Condvar, Mutex},
    time::Duration,
};

use rand::Rng;
use rayon::prelude::{ParallelBridge, ParallelIterator};

fn main() {
    let data_source = (0..500u32).rev();

    // Channel with enough capacity to hold an item from each thread
    let (tx, rx) = mpsc::sync_channel(std::thread::available_parallelism().map_or(8, |n| n.get()));

    // The gate: a counter of the next index allowed to send, plus a condvar
    // that threads wait on until the counter reaches their own index.
    let gate = Arc::new((Mutex::new(0), Condvar::new()));

    rayon::scope(|s| {
        s.spawn(move |_| {
            data_source
                .enumerate()
                .par_bridge()
                .map(|(i, value)| {
                    // pretend to do some work
                    // For the worst-case scenario, replace with:
                    // if i == 0 { std::thread::sleep(Duration::from_millis(1000)); }
                    std::thread::sleep(Duration::from_millis(
                        1000 + rand::thread_rng().gen_range(0..10),
                    ));
                    (i, value)
                })
                .for_each_with((tx, gate), |(tx, gate), (i, value)| {
                    let (lock, cond) = &**gate;
                    {
                        // Block until it is this item's turn to be sent.
                        let mut guard = cond.wait_while(lock.lock().unwrap(), |v| *v < i).unwrap();
                        let _ = tx.send((i, value));
                        *guard = i + 1;
                    }
                    // Wake the waiting threads so the next index can proceed.
                    cond.notify_all();
                });
        });
        recover_order(rx, emit);
    });
}

fn emit(item: u32) {
    println!("done with {item}");
}

fn recover_order<T>(rx: mpsc::Receiver<(usize, T)>, mut op: impl FnMut(T)) {
    let mut next_index: usize = 0;
    for (i, value) in rx {
        if i != next_index {
            // Item is out of order
            panic!("Wrong index {i}, expected {next_index}");
        }
        op(value);
        next_index += 1;
    }
}
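For the record, that combination could look roughly like this (an untested sketch; WINDOW and recover_order_windowed are names I made up here). Producers may run up to WINDOW items past the consumer's emission point, and the consumer advances the shared counter as it emits, so the reordering buffer stays at roughly WINDOW items:
use std::{
    collections::HashMap,
    sync::{mpsc, Arc, Condvar, Mutex},
    time::Duration,
};

use rand::Rng;
use rayon::prelude::{ParallelBridge, ParallelIterator};

// How far past the consumer's emission point producers may run.
const WINDOW: usize = 16;

fn main() {
    let data_source = (0..500u32).rev();
    let (tx, rx) = mpsc::sync_channel(std::thread::available_parallelism().map_or(8, |n| n.get()));
    // gate.0 counts how many items the consumer has emitted so far.
    let gate = Arc::new((Mutex::new(0usize), Condvar::new()));
    let consumer_gate = Arc::clone(&gate);

    rayon::scope(|s| {
        s.spawn(move |_| {
            data_source
                .enumerate()
                .par_bridge()
                .map(|(i, value)| {
                    // pretend to do some work
                    std::thread::sleep(Duration::from_millis(
                        1000 + rand::thread_rng().gen_range(0..10),
                    ));
                    (i, value)
                })
                .for_each_with((tx, gate), |(tx, gate), (i, value)| {
                    let (lock, cond) = &**gate;
                    // Block until this index is within WINDOW of the emission
                    // point, then release the lock *before* sending, so the
                    // consumer can always take it to advance the counter.
                    drop(
                        cond.wait_while(lock.lock().unwrap(), |emitted| i >= *emitted + WINDOW)
                            .unwrap(),
                    );
                    let _ = tx.send((i, value));
                });
        });
        recover_order_windowed(rx, consumer_gate, |item| println!("done with {item}"));
    });
}

fn recover_order_windowed<T>(
    rx: mpsc::Receiver<(usize, T)>,
    gate: Arc<(Mutex<usize>, Condvar)>,
    mut op: impl FnMut(T),
) {
    let (lock, cond) = &*gate;
    let mut next_index = 0usize;
    let mut buffer: HashMap<usize, T> = HashMap::new();
    for (i, value) in rx {
        buffer.insert(i, value);
        // Emit every item that is now contiguous, then let producers advance.
        while let Some(value) = buffer.remove(&next_index) {
            op(value);
            next_index += 1;
        }
        *lock.lock().unwrap() = next_index;
        cond.notify_all();
    }
}
Note that the producer releases the guard before sending; holding the lock across a blocking send into a full sync_channel could otherwise deadlock with the consumer waiting to update the counter.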
Upvotes: 2
Reputation: 43743
As @ChayimFriedman mentioned, this isn't necessarily feasible because rayon likes to subdivide work starting in large chunks, so the order won't be friendly. However, because you are using .par_bridge(), Rayon must take items from the Iterator in order, so the order will be close to the original order. Therefore, it is feasible to recover the original order using a buffer and .enumerate(), without consuming large amounts of memory.
Here is a demonstration program.
use std::collections::HashMap;
use std::sync::mpsc;
use std::time::Duration;

use rand::Rng;
use rayon::prelude::{ParallelBridge, ParallelIterator};

fn main() {
    let data_source = (0..500u32).rev();

    // Channel with enough capacity to hold an item from each thread
    let (tx, rx) = mpsc::sync_channel(std::thread::available_parallelism().map_or(8, |n| n.get()));

    rayon::scope(|s| {
        s.spawn(move |_| {
            data_source
                .enumerate()
                .par_bridge()
                .map(|(i, value)| {
                    // pretend to do some work
                    std::thread::sleep(Duration::from_millis(
                        1000 + rand::thread_rng().gen_range(0..10),
                    ));
                    (i, value)
                })
                .for_each_with(tx, |tx, pair| {
                    let _ = tx.send(pair);
                });
        });
        recover_order(rx, emit);
    });
}

fn emit(item: u32) {
    println!("done with {item}");
}

fn recover_order<T>(rx: mpsc::Receiver<(usize, T)>, mut op: impl FnMut(T)) {
    let mut next_index: usize = 0;
    let mut buffer: HashMap<usize, T> = HashMap::new();
    for (i, value) in rx {
        if i == next_index {
            // This item is the next one in order; emit it, then drain any
            // buffered successors that are now ready.
            op(value);
            next_index += 1;
            while let Some((_, value)) = buffer.remove_entry(&next_index) {
                op(value);
                next_index += 1;
            }
        } else {
            // Item is out of order; hold it until its turn comes.
            buffer.insert(i, value);
        }
    }
    assert!(buffer.is_empty(), "channel closed with missing items");
    println!("Buffer capacity used: {}", buffer.capacity());
}
The for_each_with() transfers items from Rayon control to the channel, and the recover_order() function consumes the channel to call emit() with the items in proper order.
The use of rayon::scope() and spawn() allows the for_each_with() parallel iteration to run “in the background” on the existing Rayon thread pool, so that the current thread can handle the receiving directly.
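Since the question also mentions writing to a file: emit can be any FnMut, so one possible variation (not part of the program above; results.txt is a made-up path) is to replace the recover_order(rx, emit) call inside rayon::scope with a closure that writes through a BufWriter:
use std::fs::File;
use std::io::{BufWriter, Write};

// Inside rayon::scope, in place of `recover_order(rx, emit)`:
let mut out = BufWriter::new(File::create("results.txt").expect("create output file"));
recover_order(rx, |item| {
    // Buffered writes avoid a syscall per line, which matters at ~700,000 items.
    writeln!(out, "done with {item}").expect("write failed");
});
// BufWriter flushes on drop, but flushing explicitly surfaces any error.
out.flush().expect("flush failed");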
Upvotes: 3