jhscheer

Reputation: 351

How to create a ring communication between threads using mpsc channels?

I want to spawn n threads that can communicate with each other in a ring topology: thread 0 can send messages to thread 1, thread 1 to thread 2, and so on, and thread n-1 back to thread 0.

This is an example of what I want to achieve with n=3:

use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;

fn main() {
    // one channel per thread: thread i sends on channel i and receives on channel i-1 (mod 3)
    let (tx0, rx0): (Sender<i32>, Receiver<i32>) = mpsc::channel();
    let (tx1, rx1): (Sender<i32>, Receiver<i32>) = mpsc::channel();
    let (tx2, rx2): (Sender<i32>, Receiver<i32>) = mpsc::channel();

    let child0 = thread::spawn(move || {
        tx0.send(0).unwrap();
        println!("thread 0 sent: 0");
        println!("thread 0 recv: {:?}", rx2.recv().unwrap());
    });
    let child1 = thread::spawn(move || {
        tx1.send(1).unwrap();
        println!("thread 1 sent: 1");
        println!("thread 1 recv: {:?}", rx0.recv().unwrap());
    });
    let child2 = thread::spawn(move || {
        tx2.send(2).unwrap();
        println!("thread 2 sent: 2");
        println!("thread 2 recv: {:?}", rx1.recv().unwrap());
    });

    // join() returns Err if the thread panicked
    child0.join().unwrap();
    child1.join().unwrap();
    child2.join().unwrap();
}

To generalize this to NTHREADS threads, I create the channels in a loop, store them in a vector, reorder the senders into a new vector, and then spawn threads, each with its own Sender/Receiver pair (tx1/rx0, tx2/rx1, etc.):

const NTHREADS: usize = 8;

// create n channels
let channels: Vec<(Sender<i32>, Receiver<i32>)> =
    (0..NTHREADS).into_iter().map(|_| mpsc::channel()).collect();

// swap the senders between tuples to create the ring topology
let mut channels_ring: Vec<(Sender<i32>, Receiver<i32>)> = (0..NTHREADS)
    .into_iter()
    .map(|i| {
        (
            channels[if i < channels.len() - 1 { i + 1 } else { 0 }].0,
            channels[i].1,
        )
    })
    .collect();

let mut children = Vec::new();
for i in 0..NTHREADS {
    let (tx, rx) = channels_ring.remove(i);

    let child = thread::spawn(move || {
        tx.send(i).unwrap();
        println!("thread {} sent: {}", i, i);
        println!("thread {} recv: {:?}", i, rx.recv().unwrap());
    });

    children.push(child);
}

for child in children {
    let _ = child.join();
}

This doesn't work, because Sender cannot be copied to create a new vector. However, if I use references (&Sender) instead:

let mut channels_ring: Vec<(&Sender<i32>, Receiver<i32>)> = (0..NTHREADS)
    .into_iter()
    .map(|i| {
        (
            &channels[if i < channels.len() - 1 { i + 1 } else { 0 }].0,
            channels[i].1,
        )
    })
    .collect();

I cannot spawn the threads, because std::sync::mpsc::Sender<i32> cannot be shared between threads safely.

Upvotes: 4

Views: 918

Answers (2)

user4815162342

Reputation: 154911

You wrote: "This doesn't work, because Sender cannot be copied to create a new vector."

While it's true that Sender cannot be copied, it does implement Clone, so you can always clone it manually. But that approach won't work for Receiver, which is not Clone and which you also need to extract from the vector.
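To illustrate the Clone route: a minimal sketch that clones each thread's "next" Sender by index, while the Receivers are moved out by consuming their vector with into_iter. The helper name ring_by_clone is hypothetical, and dropping the original senders after spawning is a choice made in this sketch, not something the answer prescribes.

```rust
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;

/// Hypothetical helper: spawns `n` threads in a ring and returns the value each thread received.
fn ring_by_clone(n: usize) -> Vec<i32> {
    // Build n channels and split them into a Vec of senders and a Vec of receivers.
    let (senders, receivers): (Vec<Sender<i32>>, Vec<Receiver<i32>>) =
        (0..n).map(|_| mpsc::channel()).unzip();

    let children: Vec<_> = receivers
        .into_iter() // consumes the Vec, so each Receiver is moved out exactly once
        .enumerate()
        .map(|(i, rx)| {
            // Sender is Clone: copy the next thread's sender instead of moving it out of the Vec.
            let tx = senders[(i + 1) % n].clone();
            thread::spawn(move || {
                tx.send(i as i32).unwrap();
                rx.recv().unwrap()
            })
        })
        .collect();

    // Drop the original senders so that only the clones held by the threads remain.
    drop(senders);

    children.into_iter().map(|c| c.join().unwrap()).collect()
}

fn main() {
    // Thread i receives from its predecessor, so this prints [3, 0, 1, 2].
    println!("{:?}", ring_by_clone(4));
}
```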

The problem with your first code is that you cannot use let foo = vec[i] to move just one value out of a vector of non-Copy values. That would leave the vector in an invalid state, with one moved-from element whose later access would cause undefined behavior. For this to work, Vec would need to track which elements had been moved out, which would impose a cost on every Vec. So instead, Vec disallows moving an element out by indexing, leaving it to the user to track moves.
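For comparison, the Vec methods that do move an element out keep the vector valid in other ways: remove shifts the later elements left, swap_remove fills the hole with the last element, and into_iter consumes the whole Vec. A small sketch, using Strings only because they are not Copy:

```rust
fn main() {
    // String is not Copy, so these values can only be moved, cloned, or borrowed.
    let mut v: Vec<String> = ["a", "b", "c", "d"].iter().map(|s| s.to_string()).collect();

    // let s = v[0]; // rejected by the compiler: cannot move a String out of an index

    // remove: moves the element out and shifts every later element one slot left.
    let a = v.remove(0);
    assert_eq!(a, "a");
    assert_eq!(v, ["b", "c", "d"]); // the old v[1] is now v[0]

    // swap_remove: moves the element out and puts the last element in its place.
    let b = v.swap_remove(0);
    assert_eq!(b, "b");
    assert_eq!(v, ["d", "c"]);

    // into_iter: consumes the Vec and yields every remaining element by value.
    let rest: Vec<String> = v.into_iter().collect();
    println!("{:?}", rest); // ["d", "c"]
}
```

The shift done by remove is also why calling channels_ring.remove(i) with an increasing i, as in the question's loop, would skip every other element and then panic with an out-of-bounds index, even once the move problem is solved.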

A simple way to move a value out of Vec is to replace Vec<T> with Vec<Option<T>> and use Option::take. foo = vec[i] is replaced with foo = vec[i].take().unwrap(), which moves the T value from the option in vec[i] (while asserting that it's not None) and leaves None, a valid variant of Option<T>, in the vector. Here is your first attempt modified in that manner:

use std::sync::mpsc;
use std::thread;

fn main() {
    const NTHREADS: usize = 8;

    let channels_ring: Vec<_> = {
        // wrap both ends in Option so each can be taken out, leaving None behind
        let mut channels: Vec<_> = (0..NTHREADS)
            .map(|_| {
                let (tx, rx) = mpsc::channel();
                (Some(tx), Some(rx))
            })
            .collect();

        (0..NTHREADS)
            .map(|rxpos| {
                let txpos = if rxpos < NTHREADS - 1 { rxpos + 1 } else { 0 };
                (
                    channels[txpos].0.take().unwrap(),
                    channels[rxpos].1.take().unwrap(),
                )
            })
            .collect()
    };

    let children: Vec<_> = channels_ring
        .into_iter()
        .enumerate()
        .map(|(i, (tx, rx))| {
            thread::spawn(move || {
                tx.send(i as i32).unwrap();
                println!("thread {} sent: {}", i, i);
                println!("thread {} recv: {:?}", i, rx.recv().unwrap());
            })
        })
        .collect();

    for child in children {
        child.join().unwrap();
    }
}
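The expression if rxpos < NTHREADS - 1 { rxpos + 1 } else { 0 } used above (and in the question) picks the next index around the ring; it is equivalent to (rxpos + 1) % NTHREADS, as this small check illustrates. The function name next_index is hypothetical.

```rust
/// Hypothetical helper: index of the ring neighbour that follows `i` among `n` positions.
fn next_index(i: usize, n: usize) -> usize {
    (i + 1) % n
}

fn main() {
    const NTHREADS: usize = 8;
    for rxpos in 0..NTHREADS {
        // The branching form used in the code above.
        let txpos = if rxpos < NTHREADS - 1 { rxpos + 1 } else { 0 };
        assert_eq!(txpos, next_index(rxpos, NTHREADS));
    }
    // The last position wraps around to the first.
    println!("{}", next_index(NTHREADS - 1, NTHREADS)); // prints 0
}
```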

Upvotes: 3

Peter Hall

Reputation: 58735

Receivers cannot be shared between threads (and Senders could not either before Rust 1.72), so you need to move them into their respective threads. That means removing them from the Vec, or else consuming the Vec while iterating over it: the vector is not permitted to be in an invalid state (with holes), even as an intermediate step. Iterating over the vectors with into_iter achieves that by consuming them.
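A minimal sketch of the difference between iter and into_iter here, assuming a simple fan-out rather than a ring: each Receiver is moved into its own thread by consuming the Vec, and the senders are then moved out the same way. The function name fan_out is hypothetical.

```rust
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;

/// Hypothetical demo: one receiving thread per channel; the main thread sends
/// `i * 10` on channel `i` and collects what each thread received.
fn fan_out(n: usize) -> Vec<i32> {
    let (senders, receivers): (Vec<Sender<i32>>, Vec<Receiver<i32>>) =
        (0..n).map(|_| mpsc::channel()).unzip();

    // receivers.iter() would yield &Receiver<i32>, which cannot be moved into a
    // thread::spawn closure (it borrows a local Vec, and Receiver is not Sync).
    // into_iter() consumes the Vec and yields each Receiver<i32> by value.
    let handles: Vec<_> = receivers
        .into_iter()
        .map(|rx| thread::spawn(move || rx.recv().unwrap()))
        .collect();

    // The senders are moved out the same way, by consuming their Vec.
    for (i, tx) in senders.into_iter().enumerate() {
        tx.send(i as i32 * 10).unwrap();
    }

    handles.into_iter().map(|h| h.join().unwrap()).collect()
}

fn main() {
    println!("{:?}", fan_out(3)); // [0, 10, 20]
}
```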

A little trick for getting the senders and receivers to pair up in a cycle is to create two vectors, one of senders and one of receivers, and then rotate one of them so that the same index into each vector gives you the pair you want.
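The effect of the rotation can be seen with plain numbers standing in for the senders and receivers (each number is its channel's index). This is only a sketch of the pairing; ring_pairs is a hypothetical name.

```rust
/// Hypothetical helper: pairs sender and receiver indices the way the rotation trick does.
fn ring_pairs(n: usize) -> Vec<(usize, usize)> {
    // Plain numbers stand in for Sender/Receiver values; each is its channel's index.
    let mut senders: Vec<usize> = (0..n).collect();
    let receivers: Vec<usize> = (0..n).collect();

    // Move the first sender to the back, e.g. [0, 1, 2, 3] becomes [1, 2, 3, 0].
    // Note: rotate_left(1) panics on an empty Vec, so n must be at least 1.
    senders.rotate_left(1);

    // Position i now holds (sender of channel (i + 1) % n, receiver of channel i).
    senders.into_iter().zip(receivers).collect()
}

fn main() {
    println!("{:?}", ring_pairs(4)); // [(1, 0), (2, 1), (3, 2), (0, 3)]
}
```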

use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;

fn main() {
    const NTHREADS: usize = 8;

    // create n channels
    let (mut senders, receivers): (Vec<Sender<i32>>, Vec<Receiver<i32>>) =
        (0..NTHREADS).into_iter().map(|_| mpsc::channel()).unzip();

    // move the first sender to the back
    senders.rotate_left(1);

    let children: Vec<_> = senders
        .into_iter()
        .zip(receivers.into_iter())
        .enumerate()
        .map(|(i, (tx, rx))| {
            thread::spawn(move || {
                tx.send(i as i32).unwrap();
                println!("thread {} sent: {}", i, i);
                println!("thread {} recv: {:?}", i, rx.recv().unwrap());
            })
        })
        .collect();

    for child in children {
        let _ = child.join();
    }
}
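One way to check that the rotation really produces a ring: a sketch of the same program that returns each thread's received value instead of printing it, and unwraps join() rather than discarding its result with let _ =, so a panicking thread is reported. The function name run_ring is hypothetical.

```rust
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;

/// Hypothetical helper: runs the rotated-senders ring with `n` threads and
/// returns, for each thread, the value it received.
fn run_ring(n: usize) -> Vec<i32> {
    let (mut senders, receivers): (Vec<Sender<i32>>, Vec<Receiver<i32>>) =
        (0..n).map(|_| mpsc::channel()).unzip();
    senders.rotate_left(1);

    let children: Vec<_> = senders
        .into_iter()
        .zip(receivers)
        .enumerate()
        .map(|(i, (tx, rx))| {
            thread::spawn(move || {
                tx.send(i as i32).unwrap();
                // Return the received value instead of printing it.
                rx.recv().unwrap()
            })
        })
        .collect();

    // join() returns Err if the thread panicked; unwrap surfaces that here.
    children.into_iter().map(|c| c.join().unwrap()).collect()
}

fn main() {
    let n = 8;
    let received = run_ring(n);
    // Thread i should hear from its predecessor, thread (i + n - 1) % n.
    for (i, &value) in received.iter().enumerate() {
        assert_eq!(value as usize, (i + n - 1) % n);
    }
    println!("{:?}", received); // [7, 0, 1, 2, 3, 4, 5, 6]
}
```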

Upvotes: 7
