Trần Kim Dự

Reputation: 6112

How to start and stop a worker thread

I have the following requirement, which is standard in other programming languages, but I don't know how to do it in Rust.

I have a class, and I want to write a method that spawns a worker thread satisfying two conditions:

For example, here is my dummy code:

struct A {
    thread: JoinHandle<?>,
}

impl A {
    pub fn run(&mut self) -> Result<()>{
        self.thread = thread::spawn(move || {
            let mut i = 0;
            loop {
                self.call();
                i = 1 + i;
                if i > 5 {
                    return
                }
            }
        });
        Ok(())
    }

    pub fn stop(&mut self) -> std::thread::Result<_> {
        self.thread.join()
    }

    pub fn call(&mut self) {
        println!("hello world");
    }
}

fn main() {
    let mut a = A{};
    a.run();
}

I get an error at thread: JoinHandle<?>. What is the type of thread in this case? And is my code correct for starting and stopping a worker thread?

Upvotes: 1

Views: 2149

Answers (2)

vallentin

Reputation: 26245

In short, the T in JoinHandle<T> is the return type of the closure passed to thread::spawn(), and join() hands that value back to you. So in your case JoinHandle<?> would need to be JoinHandle<()>, as your closure returns nothing, i.e. () (unit).
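To make the relationship between the closure and the handle type concrete, here is a minimal sketch. The function names spawn_unit and spawn_sum are made up for illustration and are not part of your code.

```rust
use std::thread::{self, JoinHandle};

// The closure returns nothing, so the handle is `JoinHandle<()>`.
fn spawn_unit() -> JoinHandle<()> {
    thread::spawn(|| {
        println!("hello world");
    })
}

// The closure returns a `u32`, so the handle is `JoinHandle<u32>`,
// and `join()` yields that `u32` wrapped in a `thread::Result`.
fn spawn_sum() -> JoinHandle<u32> {
    thread::spawn(|| (1u32..=5).sum::<u32>())
}
```

Calling spawn_sum().join() gives back a thread::Result<u32>, which is Err only if the thread panicked.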

Other than that, your dummy code contains a few additional issues.

  • The return type of run() is incorrect, and would need to at least be Result<(), ()>.
  • The thread field would need to be Option<JoinHandle<()>> for fn stop(&mut self) to work, as join() consumes the JoinHandle.
  • However, you're attempting to pass &mut self into the closure, which brings a lot more issues, boiling down to multiple mutable references.
    • This could be solved with e.g. Mutex<A>. However, calling stop() could then lead to a deadlock instead.

However, since it was dummy code and you clarified in the comments, let me try to cover what you meant with a few examples. This includes rewriting your dummy code.

Result after worker is done

If you don't need access to the data while the worker thread is running, then you can make a new struct WorkerData. In run() you copy/clone the data you need from A (or, as I've renamed it, Worker) into it. At the end of the closure you return the data, so you can retrieve it through join().

use std::thread::{self, JoinHandle};

struct WorkerData {
    ...
}

impl WorkerData {
    pub fn call(&mut self) {
        println!("hello world");
    }
}

struct Worker {
    thread: Option<JoinHandle<WorkerData>>,
}

impl Worker {
    pub fn new() -> Self {
        Self { thread: None }
    }

    pub fn run(&mut self) {
        // Create `WorkerData` and copy/clone whatever is needed from `self`
        let mut data = WorkerData {};

        self.thread = Some(thread::spawn(move || {
            let mut i = 0;
            loop {
                data.call();
                i = 1 + i;
                if i > 5 {
                    // Return `data` so we get it back through `join()`
                    return data;
                }
            }
        }));
    }

    pub fn stop(&mut self) -> Option<thread::Result<WorkerData>> {
        if let Some(handle) = self.thread.take() {
            Some(handle.join())
        } else {
            None
        }
    }
}
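As a rough usage sketch of this pattern, assume WorkerData carries a hypothetical calls counter (my own addition, purely so there is something to read back). The helper run_once is likewise made up for illustration.

```rust
use std::thread::{self, JoinHandle};

// Hypothetical payload: counts how often `call()` ran.
struct WorkerData {
    calls: u32,
}

impl WorkerData {
    fn call(&mut self) {
        self.calls += 1;
    }
}

struct Worker {
    thread: Option<JoinHandle<WorkerData>>,
}

impl Worker {
    fn new() -> Self {
        Self { thread: None }
    }

    fn run(&mut self) {
        let mut data = WorkerData { calls: 0 };
        self.thread = Some(thread::spawn(move || {
            for _ in 0..6 {
                data.call();
            }
            // Hand the data back to whoever calls `join()`.
            data
        }));
    }

    // `None` means `run()` was never called or `stop()` already took the handle.
    fn stop(&mut self) -> Option<thread::Result<WorkerData>> {
        self.thread.take().map(|handle| handle.join())
    }
}

// Runs the worker once and reports how many calls it made.
fn run_once() -> u32 {
    let mut worker = Worker::new();
    worker.run();
    let data = worker
        .stop()
        .expect("worker was started")
        .expect("worker panicked");
    // A second `stop()` finds no handle left.
    assert!(worker.stop().is_none());
    data.calls
}
```

Note that stop() here only waits for the thread to finish on its own; it does not interrupt it.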

You don't really need thread to be Option<JoinHandle<WorkerData>>; you could just use JoinHandle<WorkerData>. If you wanted to call run() again, it would be easier to reassign the variable holding the Worker.

So now we can simplify Worker by removing the Option, changing stop() to consume self (and with it thread), and replacing run(&mut self) with a new() -> Self that spawns the thread.

use std::thread::{self, JoinHandle};

struct Worker {
    thread: JoinHandle<WorkerData>,
}

impl Worker {
    pub fn new() -> Self {
        // Create `WorkerData` with whatever the worker thread needs
        let mut data = WorkerData {};

        let thread = thread::spawn(move || {
            let mut i = 0;
            loop {
                data.call();
                i = 1 + i;
                if i > 5 {
                    return data;
                }
            }
        });

        Self { thread }
    }

    pub fn stop(self) -> thread::Result<WorkerData> {
        self.thread.join()
    }
}

Shared WorkerData

If you want to retain references to WorkerData between multiple threads, then you'd need to use Arc. Since you additionally want to be able to mutate it, you'll need to use a Mutex.

If you'll only be mutating within a single thread, then you could alternatively use a RwLock, which, unlike a Mutex, lets multiple readers lock it and obtain immutable references at the same time.

use std::sync::{Arc, RwLock};
use std::thread::{self, JoinHandle};

struct Worker {
    thread: JoinHandle<()>,
    data: Arc<RwLock<WorkerData>>,
}

impl Worker {
    pub fn new() -> Self {
        // Create `WorkerData` with whatever the worker thread needs
        let data = Arc::new(RwLock::new(WorkerData {}));

        let thread = thread::spawn({
            let data = data.clone();
            move || {
                let mut i = 0;
                loop {
                    if let Ok(mut data) = data.write() {
                        data.call();
                    }

                    i = 1 + i;
                    if i > 5 {
                        return;
                    }
                }
            }
        });

        Self { thread, data }
    }

    pub fn stop(self) -> thread::Result<Arc<RwLock<WorkerData>>> {
        self.thread.join()?;
        // You might be able to get the inner `WorkerData` here via
        // `Arc::try_unwrap` followed by `RwLock::into_inner`
        Ok(self.data)
    }
}
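On the "unwrap and get the inner WorkerData" remark: once the thread has been joined and no other clone of the Arc is alive, Arc::try_unwrap followed by RwLock::into_inner should recover the value. A minimal sketch, assuming a hypothetical calls field and a made-up helper run_and_unwrap:

```rust
use std::sync::{Arc, RwLock};
use std::thread;

// Hypothetical payload, for illustration only.
#[derive(Debug, PartialEq)]
struct WorkerData {
    calls: u32,
}

// Spawns a worker that shares `data`, joins it, then takes the data
// back out of the `Arc<RwLock<_>>`.
fn run_and_unwrap() -> Option<WorkerData> {
    let data = Arc::new(RwLock::new(WorkerData { calls: 0 }));

    let handle = thread::spawn({
        let data = Arc::clone(&data);
        move || {
            for _ in 0..6 {
                if let Ok(mut data) = data.write() {
                    data.calls += 1;
                }
            }
        }
    });

    // `None` if the worker panicked.
    handle.join().ok()?;

    // The thread's clone was dropped when its closure finished, so this
    // is the only remaining reference. `try_unwrap` fails (returning the
    // `Arc`) if any other clone is still alive.
    let lock = Arc::try_unwrap(data).ok()?;

    // `into_inner` fails only if the lock was poisoned by a panic.
    lock.into_inner().ok()
}
```

If some other part of the program still holds a clone of the Arc, try_unwrap returns Err and you'd have to fall back to locking instead.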

If you add a method that exposes data as Arc<RwLock<WorkerData>>, then cloning the Arc and locking the inner RwLock prior to calling stop() would result in a deadlock. To avoid that, any data() method should return &WorkerData or &mut WorkerData instead of the Arc. That way you'd be unable to call stop() while still holding on to the data, and so couldn't cause a deadlock.
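One way such an accessor might look is sketched below. With a RwLock, the method would in practice return a lock guard rather than a bare reference. That is my assumption about how to realise the idea, as are the calls field and the peek_then_stop helper. The guard borrows the Worker, so the borrow checker rejects stop(self) while the guard is alive.

```rust
use std::sync::{Arc, RwLock, RwLockReadGuard};
use std::thread::{self, JoinHandle};

// Hypothetical payload, for illustration only.
struct WorkerData {
    calls: u32,
}

struct Worker {
    thread: JoinHandle<()>,
    data: Arc<RwLock<WorkerData>>,
}

impl Worker {
    fn new() -> Self {
        let data = Arc::new(RwLock::new(WorkerData { calls: 0 }));
        let thread = thread::spawn({
            let data = Arc::clone(&data);
            move || {
                for _ in 0..6 {
                    if let Ok(mut data) = data.write() {
                        data.calls += 1;
                    }
                }
            }
        });
        Self { thread, data }
    }

    // The returned guard borrows `self`, so `stop(self)` cannot be
    // called until the guard has been dropped.
    fn data(&self) -> RwLockReadGuard<'_, WorkerData> {
        self.data.read().unwrap()
    }

    fn stop(self) -> thread::Result<()> {
        self.thread.join()
    }
}

fn peek_then_stop() -> bool {
    let worker = Worker::new();
    {
        let snapshot = worker.data();
        assert!(snapshot.calls <= 6);
    } // The read guard is released here.
    worker.stop().is_ok()
}
```

Moving worker.stop() inside the inner block, while snapshot is still alive, would be a compile error rather than a runtime deadlock.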

Flag to stop worker

If you actually want to stop the worker thread, then you'd have to use a flag to signal it to do so. You can create a flag in the form of a shared AtomicBool.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};
use std::thread::{self, JoinHandle};

struct Worker {
    thread: JoinHandle<()>,
    data: Arc<RwLock<WorkerData>>,
    stop_flag: Arc<AtomicBool>,
}

impl Worker {
    pub fn new() -> Self {
        // Create `WorkerData` with whatever the worker thread needs
        let data = Arc::new(RwLock::new(WorkerData {}));

        let stop_flag = Arc::new(AtomicBool::new(false));

        let thread = thread::spawn({
            let data = data.clone();
            let stop_flag = stop_flag.clone();
            move || {
                // let mut i = 0;
                loop {
                    if stop_flag.load(Ordering::Relaxed) {
                        break;
                    }

                    if let Ok(mut data) = data.write() {
                        data.call();
                    }

                    // i = 1 + i;
                    // if i > 5 {
                    //     return;
                    // }
                }
            }
        });

        Self {
            thread,
            data,
            stop_flag,
        }
    }

    pub fn stop(self) -> thread::Result<Arc<RwLock<WorkerData>>> {
        self.stop_flag.store(true, Ordering::Relaxed);
        self.thread.join()?;
        // You might be able to get the inner `WorkerData` here via
        // `Arc::try_unwrap` followed by `RwLock::into_inner`
        Ok(self.data)
    }
}
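Stripped down to just the flag, the start/stop handshake could look roughly like this. The run_until_stopped helper, the iteration counter, and the 1 ms sleep standing in for real work are all assumptions of mine.

```rust
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

// Runs a worker for roughly `run_for`, then signals it to stop and
// joins it. Returns how many loop iterations the worker completed.
fn run_until_stopped(run_for: Duration) -> usize {
    let stop_flag = Arc::new(AtomicBool::new(false));
    let iterations = Arc::new(AtomicUsize::new(0));

    let handle = thread::spawn({
        let stop_flag = Arc::clone(&stop_flag);
        let iterations = Arc::clone(&iterations);
        move || {
            // Check the flag once per iteration.
            while !stop_flag.load(Ordering::Relaxed) {
                iterations.fetch_add(1, Ordering::Relaxed);
                // Placeholder for one unit of real work.
                thread::sleep(Duration::from_millis(1));
            }
        }
    });

    thread::sleep(run_for);

    // Ask the worker to stop, then wait until it actually has.
    stop_flag.store(true, Ordering::Relaxed);
    handle.join().expect("worker panicked");

    iterations.load(Ordering::Relaxed)
}
```

The worker only notices the flag between iterations, so stop takes at most about one unit of work to return. A worker that blocks indefinitely inside an iteration would never see the flag.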

Multiple threads and multiple tasks

If you want multiple kinds of tasks processed, spread across multiple threads, then here's a more generalized example.

You already mentioned using mpsc. So you can use a Sender and Receiver along with a custom Task and TaskResult enum.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};

pub enum Task {
    ...
}

pub enum TaskResult {
    ...
}

pub type TaskSender = Sender<Task>;
pub type TaskReceiver = Receiver<Task>;

pub type ResultSender = Sender<TaskResult>;
pub type ResultReceiver = Receiver<TaskResult>;

struct Worker {
    threads: Vec<JoinHandle<()>>,
    task_sender: TaskSender,
    result_receiver: ResultReceiver,
    stop_flag: Arc<AtomicBool>,
}

impl Worker {
    pub fn new(num_threads: usize) -> Self {
        let (task_sender, task_receiver) = mpsc::channel();
        let (result_sender, result_receiver) = mpsc::channel();

        let task_receiver = Arc::new(Mutex::new(task_receiver));

        let stop_flag = Arc::new(AtomicBool::new(false));

        Self {
            threads: (0..num_threads)
                .map(|_| {
                    let task_receiver = task_receiver.clone();
                    let result_sender = result_sender.clone();
                    let stop_flag = stop_flag.clone();

                    thread::spawn(move || loop {
                        if stop_flag.load(Ordering::Relaxed) {
                            break;
                        }

                        let task_receiver = task_receiver.lock().unwrap();

                        if let Ok(task) = task_receiver.recv() {
                            drop(task_receiver);

                            // Perform the `task` here

                            // If the `Task` results in a `TaskResult` then create it and send it back
                            let result: TaskResult = ...;
                            // The `SendError` can be ignored as it only occurs if the receiver
                            // has already been deallocated
                            let _ = result_sender.send(result);
                        } else {
                            break;
                        }
                    })
                })
                .collect(),
            task_sender,
            result_receiver,
            stop_flag,
        }
    }

    pub fn stop(self) -> Vec<thread::Result<()>> {
        drop(self.task_sender);

        self.stop_flag.store(true, Ordering::Relaxed);

        self.threads
            .into_iter()
            .map(|t| t.join())
            .collect::<Vec<_>>()
    }

    #[inline]
    pub fn request(&mut self, task: Task) {
        self.task_sender.send(task).unwrap();
    }

    #[inline]
    pub fn result_receiver(&mut self) -> &ResultReceiver {
        &self.result_receiver
    }
}

An example of using the Worker to send tasks and receive task results would then look like this:

fn main() {
    let mut worker = Worker::new(4);

    // Request that a `Task` is performed
    worker.request(task);

    // Receive a `TaskResult` if any are pending
    if let Ok(result) = worker.result_receiver().try_recv() {
        // Process the `TaskResult`
    }
}
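Since Task and TaskResult are left open above, here is a condensed, self-contained sketch with made-up variants (Task::Square and TaskResult::Squared) to show the whole round trip. It is a simplification, not the same design: I've dropped the stop flag and rely only on closing the task channel, so queued tasks are drained before the threads exit. Pool, finish and square_all are hypothetical names.

```rust
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};

// Hypothetical task and result types, chosen only for illustration.
enum Task {
    Square(u64),
}

enum TaskResult {
    Squared(u64),
}

struct Pool {
    threads: Vec<JoinHandle<()>>,
    task_sender: Sender<Task>,
    result_receiver: Receiver<TaskResult>,
}

impl Pool {
    fn new(num_threads: usize) -> Self {
        let (task_sender, task_receiver) = mpsc::channel::<Task>();
        let (result_sender, result_receiver) = mpsc::channel();
        // `Receiver` cannot be cloned, so the workers share it behind a mutex.
        let task_receiver = Arc::new(Mutex::new(task_receiver));

        let threads = (0..num_threads)
            .map(|_| {
                let task_receiver = Arc::clone(&task_receiver);
                let result_sender = result_sender.clone();
                thread::spawn(move || loop {
                    // The lock guard is a temporary, so it is released
                    // as soon as `recv()` returns.
                    let task = task_receiver.lock().unwrap().recv();
                    match task {
                        Ok(Task::Square(n)) => {
                            let _ = result_sender.send(TaskResult::Squared(n * n));
                        }
                        // All senders are gone: no more work will arrive.
                        Err(_) => break,
                    }
                })
            })
            .collect();

        Self { threads, task_sender, result_receiver }
    }

    fn request(&self, task: Task) {
        self.task_sender.send(task).unwrap();
    }

    // Closes the task channel, waits for the workers to drain it,
    // and returns every result that was produced.
    fn finish(self) -> Vec<TaskResult> {
        drop(self.task_sender);
        for handle in self.threads {
            handle.join().expect("worker panicked");
        }
        self.result_receiver.try_iter().collect()
    }
}

// Squares every input on the pool and returns the results sorted,
// since completion order across threads is not deterministic.
fn square_all(inputs: &[u64], num_threads: usize) -> Vec<u64> {
    let pool = Pool::new(num_threads);
    for &n in inputs {
        pool.request(Task::Square(n));
    }
    let mut outputs: Vec<u64> = pool
        .finish()
        .into_iter()
        .map(|TaskResult::Squared(n)| n)
        .collect();
    outputs.sort_unstable();
    outputs
}
```

Whether to drain the queue on shutdown (as here) or abandon pending tasks via a stop flag depends on what your tasks are; neither is the one right answer.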

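In a few cases you might need to implement Send for Task and/or TaskResult. Only do this if you can guarantee that moving the type to another thread is actually sound, since an unsafe impl bypasses the compiler's check. Check out "Understanding the Send trait".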

unsafe impl Send for Task {}
unsafe impl Send for TaskResult {}
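Before reaching for unsafe impl, it may be worth checking whether the type is already Send, because the compiler derives Send automatically when every field is Send. A small compile-time check, with a hypothetical Task definition and made-up helper names (assert_send, check):

```rust
// Hypothetical task type made up only of owned, thread-safe data.
#[allow(dead_code)]
enum Task {
    Print(String),
    Sum(Vec<u64>),
}

// Compiles only if `T` is `Send`; it does nothing at runtime.
fn assert_send<T: Send>() {}

fn check() {
    // `Task` is `Send` automatically because `String` and `Vec<u64>` are.
    assert_send::<Task>();

    // A type containing an `Rc` is not `Send`, so this line would be
    // rejected by the compiler if uncommented:
    // assert_send::<std::rc::Rc<String>>();
}
```

If the check fails, replacing the offending field (e.g. Rc with Arc) is usually a safer fix than an unsafe impl.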

Upvotes: 10

ddulaney

Reputation: 1101

The type parameter of a JoinHandle should be the return type of the thread's function.

In this case, the return type is the empty tuple (), pronounced "unit". It is used when there is only one possible value, and it is the implicit return type of functions when no return type is specified.

You can just write JoinHandle<()> to represent that the function will not return anything.

(Note: Your code will run into some borrow checker issues with self.call(), which will probably need to be solved with Arc<Mutex<Self>>, but that's another question.)
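For what it's worth, the Arc<Mutex<_>> approach might look roughly like the sketch below. The calls field and the run_shared helper are my own additions for illustration and are not in the question's code.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

struct A {
    // Hypothetical field so there is some state to mutate.
    calls: u32,
}

impl A {
    fn call(&mut self) {
        self.calls += 1;
        println!("hello world");
    }
}

// Shares one `A` between the spawning thread and a worker thread.
fn run_shared() -> u32 {
    let a = Arc::new(Mutex::new(A { calls: 0 }));

    let handle = thread::spawn({
        // The worker gets its own `Arc` clone instead of borrowing `self`.
        let a = Arc::clone(&a);
        move || {
            for _ in 0..6 {
                // Lock only for the duration of one call.
                a.lock().unwrap().call();
            }
        }
    });

    handle.join().expect("worker panicked");

    // Read the result through the original `Arc`.
    let calls = a.lock().unwrap().calls;
    calls
}
```

The important part is that the closure owns a clone of the Arc, which satisfies the 'static bound on thread::spawn that a borrowed &mut self cannot.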

Upvotes: 1
