Wouter De Coster

Reputation: 528

Rust: After parallelization with rayon, write out results in order without waiting until the end

I am using rayon's .par_bridge().for_each(|r| { do_work(r) }) to run tasks in parallel over an iterator (specifically, Records from a bed file, but I don't think that matters). There could be up to ~700,000 tasks.

I want to print the result of every call to do_work() (to stdout or to a file), but only in the order of the original iterator. I could sort all the output after every parallel job has completed, but storing all results until the end would require much more memory. I could add .enumerate() to get an index for each item, print an item as soon as it is its turn, and store the rest until their turn comes, but I am not sure how to best implement such a system, or whether it is the best solution at all. What would you suggest?
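For context, a minimal sketch of what the current (unordered) setup looks like; do_work and the numeric range are placeholders for the real bed-file records:

use rayon::prelude::{ParallelBridge, ParallelIterator};

// do_work stands in for the real per-record computation.
fn do_work(record: u32) -> String {
    format!("processed record {record}")
}

fn main() {
    // Placeholder for the ~700,000 bed-file records.
    let records = 0..700_000u32;
    records.par_bridge().for_each(|r| {
        let result = do_work(r);
        // Results are printed in completion order here, not in input order.
        println!("{result}");
    });
}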

Upvotes: 1

Views: 1450

Answers (2)

Nahor

Reputation: 101

Here is a version adapted from @kevin-reid's answer.
As mentioned in my comment on his post, with his code, if one thread is slow, the others keep working anyway, so more and more data accumulates in the buffer. In the worst case, everything but the first item can end up stored in that buffer.

My adaptation blocks a thread from sending its result until it is its turn. This means that, at most, there will be one result per thread in memory (plus whatever is waiting in the sync_channel).
The main drawback is the need for the Mutex and notify_all, which can have a significant cost when each unit of work is quick.

It's also possible to combine both answers to let the threads run ahead, but only within a bounded window; one way to do that is sketched after the code below 😜.

use std::{
    sync::{mpsc, Arc, Condvar, Mutex},
    time::Duration,
};

use rand::Rng;
use rayon::prelude::{ParallelBridge, ParallelIterator};

fn main() {
    let data_source = (0..500u32).rev();

    // Channel with enough capacity to hold an item from each thread
    let (tx, rx) = mpsc::sync_channel(std::thread::available_parallelism().map_or(8, |n| n.get()));

    let gate = Arc::new((Mutex::new(0), Condvar::new()));

    rayon::scope(|s| {
        s.spawn(move |_| {
            data_source
                .enumerate()
                .par_bridge()
                .map(|(i, value)| {
                    // pretend to do some work
                    // For worst case scenario, replace with:
                    //    if i == 0 { std::thread::sleep(Duration::from_millis(1000)); }
                    std::thread::sleep(Duration::from_millis(
                        1000 + rand::thread_rng().gen_range(0..10),
                    ));

                    (i, value)
                })
                .for_each_with((tx, gate), |(tx, gate), (i, value)| {
                    let (lock, cond) = &**gate;
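                    // Wait until the gate counter reaches this item's index, then
                    // send the result and bump the counter so the next index can go.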

                    {
                        let mut guard = cond.wait_while(lock.lock().unwrap(), |v| *v < i).unwrap();
                        let _ = tx.send((i, value));
                        *guard = i + 1;
                    }

                    cond.notify_all();
                });
        });

        recover_order(rx, emit);
    });
}

fn emit(item: u32) {
    println!("done with {item}");
}

fn recover_order<T>(rx: mpsc::Receiver<(usize, T)>, mut op: impl FnMut(T)) {
    let mut next_index: usize = 0;
    for (i, value) in rx {
        if i != next_index {
            // Item is out of order
            panic!("Wrong index {i}, expected {next_index}");
        }
        op(value);
        next_index += 1;
    }
}
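For the combination mentioned above, here is one possible shape (this sketch is not part of the original answer; WINDOW, recover_order_windowed, and the receiver-driven counter are illustrative choices). It keeps the other answer's HashMap buffer on the receiving side, but the receiver advances the shared counter after each emit, so a worker only blocks once its item is more than WINDOW positions ahead of what has already been printed:

use std::{
    collections::HashMap,
    sync::{mpsc, Arc, Condvar, Mutex},
    time::Duration,
};

use rand::Rng;
use rayon::prelude::{ParallelBridge, ParallelIterator};

// Illustrative knob: how far ahead of the receiver a worker may run.
const WINDOW: usize = 16;

fn main() {
    let data_source = (0..500u32).rev();

    // Channel with enough capacity to hold an item from each thread
    let (tx, rx) = mpsc::sync_channel(std::thread::available_parallelism().map_or(8, |n| n.get()));

    // The counter now tracks the next index the *receiver* expects to emit.
    let gate = Arc::new((Mutex::new(0usize), Condvar::new()));
    let gate_rx = Arc::clone(&gate);

    rayon::scope(|s| {
        s.spawn(move |_| {
            data_source
                .enumerate()
                .par_bridge()
                .map(|(i, value)| {
                    // pretend to do some work
                    std::thread::sleep(Duration::from_millis(
                        1000 + rand::thread_rng().gen_range(0..10),
                    ));
                    (i, value)
                })
                .for_each_with((tx, gate), |(tx, gate), (i, value)| {
                    let (lock, cond) = &**gate;
                    {
                        // Block only while this item is more than WINDOW positions
                        // ahead of the receiver; the guard is dropped before the
                        // (possibly blocking) send so it never holds up the lock.
                        let _guard = cond
                            .wait_while(lock.lock().unwrap(), |next| i >= *next + WINDOW)
                            .unwrap();
                    }
                    let _ = tx.send((i, value));
                });
        });

        recover_order_windowed(rx, gate_rx, emit);
    });
}

fn emit(item: u32) {
    println!("done with {item}");
}

fn recover_order_windowed<T>(
    rx: mpsc::Receiver<(usize, T)>,
    gate: Arc<(Mutex<usize>, Condvar)>,
    mut op: impl FnMut(T),
) {
    let (lock, cond) = &*gate;
    let mut next_index: usize = 0;
    let mut buffer: HashMap<usize, T> = HashMap::new();
    for (i, value) in rx {
        buffer.insert(i, value);
        // Emit everything that is now in order...
        while let Some(value) = buffer.remove(&next_index) {
            op(value);
            next_index += 1;
        }
        // ...then let workers blocked on the window catch up.
        *lock.lock().unwrap() = next_index;
        cond.notify_all();
    }
    assert!(buffer.is_empty(), "channel closed with missing items");
}

With this variant, memory use is bounded by roughly WINDOW buffered items plus whatever sits in the sync_channel, while a single slow item no longer stalls every other worker immediately.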

Upvotes: 2

Kevin Reid

Reputation: 43743

As @ChayimFriedman mentioned, this isn't necessarily feasible in general, because rayon likes to subdivide work into large chunks up front, so the completion order won't be friendly. However, because you are using .par_bridge(), Rayon must take items from the Iterator in order, so the completion order will stay close to the original order. It is therefore feasible to recover the original order with .enumerate() and a small buffer, without consuming large amounts of memory.

Here is a demonstration program.

use std::collections::HashMap;
use std::sync::mpsc;
use std::time::Duration;

use rand::Rng;
use rayon::prelude::{ParallelBridge, ParallelIterator};

fn main() {
    let data_source = (0..500u32).rev();

    // Channel with enough capacity to hold an item from each thread
    let (tx, rx) = mpsc::sync_channel(std::thread::available_parallelism().map_or(8, |n| n.get()));

    rayon::scope(|s| {
        s.spawn(move |_| {
            data_source
                .enumerate()
                .par_bridge()
                .map(|(i, value)| {
                    // pretend to do some work
                    std::thread::sleep(Duration::from_millis(
                        1000 + rand::thread_rng().gen_range(0..10),
                    ));
                    (i, value)
                })
                .for_each_with(tx, |tx, pair| {
                    let _ = tx.send(pair);
                });
        });

        recover_order(rx, emit);
    });
}

fn emit(item: u32) {
    println!("done with {item}");
}

fn recover_order<T>(rx: mpsc::Receiver<(usize, T)>, mut op: impl FnMut(T)) {
    let mut next_index: usize = 0;
    let mut buffer: HashMap<usize, T> = HashMap::new();
    for (i, value) in rx {
        if i == next_index {
            op(value);
            next_index += 1;
            while let Some((_, value)) = buffer.remove_entry(&next_index) {
                op(value);
                next_index += 1;
            }
        } else {
            // Item is out of order
            buffer.insert(i, value);
        }
    }

    assert!(buffer.is_empty(), "channel closed with missing items");

    println!("Buffer capacity used: {}", buffer.capacity());
}

The for_each_with() transfers items out of Rayon's control into the channel, and the recover_order() function consumes the channel, calling emit() with the items in their proper order.

The use of rayon::scope() and spawn() allows the for_each_with() parallel iteration to run “in the background” on the existing Rayon thread pool, so that the current thread can handle the receiving directly.

Upvotes: 3
