lulijun

How to fire async callbacks in Rust

I'm trying to implement a StateMachine in Rust, but I ran into problems when trying to fire the StateMachine's callbacks from a spawned task.

Here is my StateMachine struct. The state is a generic T because I want to use it in many different scenarios, and I use a Vec to store all the callbacks that are registered with this StateMachine.

At first I didn't use the lifetime 'a, but that ran into lifetime problems, so I added 'a following this suggestion: Idiomatic callbacks in Rust

pub struct StateMachine<'a, T> where T:Clone+Eq+'a {
    state: RwLock<T>,
    listeners2: Vec<Arc<Mutex<ListenerCallback<'a, T>>>>,
}

pub type ListenerCallback<'a, T> = dyn FnMut(T) -> Result<()> + Send + Sync + 'a ;

When the state changes, the StateMachine fires all the callbacks, as follows.

pub async fn try_set(&mut self, new_state: T) -> Result<()> {
    if block_on(self.state.read()).deref().eq(&new_state) {
        return Ok(());
    }
    // todo change the state

    // fire every listener in a spawned task
    let mut fire_results = vec![];
    for listener in &mut self.listeners2 {
        let state = new_state.clone();
        let fire_listener = listener.clone();
        fire_results.push(tokio::spawn(async move {
            let mut guard = fire_listener.lock().unwrap();
            // return the listener's Result so the caller can inspect it
            guard.deref_mut()(state)
        }));
    }
    // if a listener returned Err (or its task panicked), return it
    for fire_result in fire_results {
        fire_result.await??;
    }
    Ok(())
}

But this causes a compilation error:

error[E0521]: borrowed data escapes outside of associated function
  --> src/taf/taf-core/src/execution/state_machine.rs:54:33
   |
15 | impl<'a,T> StateMachine<'a,T> where T:Clone+Eq+Send {
   |      -- lifetime `'a` defined here
...
34 |     pub async fn try_set(&mut self, new_state:T) -> Result<()> {
   |                          --------- `self` is a reference that is only valid in the associated function body
...
54 |             let fire_listener = listener.clone();
   |                                 ^^^^^^^^^^^^^^^^
   |                                 |
   |                                 `self` escapes the associated function body here
   |                                 argument requires that `'a` must outlive `'static`

##########################################################

The full code is coupled with a lot of business logic, so I wrote two demos that reproduce the same problem. The first demo fires the callbacks synchronously, and it works. The second demo tries to fire the callbacks asynchronously and hits the same error: self escapes the associated function body here.

First demo (it works):

use std::ops::DerefMut;
use std::sync::{Arc, RwLock};
use anyhow::Result;
use dashmap::DashMap;

struct StateMachine<'a,T> where T:Clone+Eq+'a {
    state: T,
    listeners: Vec<Box<Callback<'a, T>>>,
}

type Callback<'a, T> = dyn FnMut(T) -> Result<()> + Send + Sync + 'a;

impl<'a, T> StateMachine<'a,T> where T:Clone+Eq+'a {

    pub fn new(init_state: T) -> Self {
        StateMachine {
            state: init_state,
            listeners: vec![]
        }
    }

    pub fn add_listener(&mut self, listener: Box<Callback<'a, T>>) -> Result<()> {
        self.listeners.push(listener);
        Ok(())
    }

    pub fn set(&mut self, new_state: T) -> Result<()> {

        self.state = new_state.clone();

        for listener in &mut self.listeners {
            // propagate the first listener error instead of ignoring it
            listener(new_state.clone())?;
        }
        Ok(())
    }
}

#[derive(Clone, Eq, PartialEq, Hash)]
enum ExeState {
    Waiting,
    Running,
    Finished,
    Failed,
}

struct Execution<'a> {
    exec_id: String,
    pub state_machine: StateMachine<'a, ExeState>,
}

struct ExecManager<'a> {
    all_jobs: Arc<RwLock<DashMap<String, Execution<'a>>>>,
    finished_jobs: Arc<RwLock<Vec<String>>>,
}

impl<'a> ExecManager<'a> {

    pub fn new() -> Self {
        ExecManager {
            all_jobs: Arc::new(RwLock::new(DashMap::new())),
            finished_jobs: Arc::new(RwLock::new(vec![]))
        }
    }

    fn add_job(&mut self, job_id: String) {
        let mut execution = Execution {
            exec_id: job_id.clone(),
            state_machine: StateMachine::new(ExeState::Waiting)
        };

        // add listener
        let callback_finished_jobs = self.finished_jobs.clone();
        let callback_job_id = job_id.clone();
        execution.state_machine.add_listener( Box::new(move |new_state| {
            println!("listener fired!, job_id {}", callback_job_id.clone());
            if new_state == ExeState::Finished || new_state == ExeState::Failed {
                let mut guard = callback_finished_jobs.write().unwrap();
                guard.deref_mut().push(callback_job_id.clone());

            }
            Ok(())
        }));

        let mut guard = self.all_jobs.write().unwrap();
        guard.deref_mut().insert(job_id, execution);
    }

    fn mock_exec(&mut self, job_id: String) {
        let mut guard = self.all_jobs.write().unwrap();
        let mut exec = guard.deref_mut().get_mut(&job_id).unwrap();

        exec.state_machine.set(ExeState::Finished);
    }

}


#[test]
fn test() {
    let mut manager = ExecManager::new();

    manager.add_job(String::from("job_id1"));
    manager.add_job(String::from("job_id2"));

    manager.mock_exec(String::from("job_id1"));
    manager.mock_exec(String::from("job_id2"));


}

Second demo:

use std::ops::DerefMut;
use std::sync::{Arc, Mutex, RwLock};
use anyhow::Result;
use dashmap::DashMap;

struct StateMachine<'a,T> where T:Clone+Eq+Send+'a {
    state: T,
    listeners: Vec<Arc<Mutex<Box<Callback<'a, T>>>>>,
}

type Callback<'a, T> = dyn FnMut(T) -> Result<()> + Send + Sync + 'a;

impl<'a, T> StateMachine<'a,T> where T:Clone+Eq+Send+'a {

    pub fn new(init_state: T) -> Self {
        StateMachine {
            state: init_state,
            listeners: vec![]
        }
    }

    pub fn add_listener(&mut self, listener: Box<Callback<'a, T>>) -> Result<()> {
        self.listeners.push(Arc::new(Mutex::new(listener)));
        Ok(())
    }

    pub fn set(&mut self, new_state: T) -> Result<()> {

        self.state = new_state.clone();

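        for listener in &mut self.listeners {
            // clone per iteration: `async move` would otherwise move new_state
            // into the first task and fail to compile on the next iteration
            let state = new_state.clone();
            let spawn_listener = listener.clone();
            tokio::spawn(async move {
                let mut guard = spawn_listener.lock().unwrap();
                // result discarded: nobody awaits this task
                let _ = guard.deref_mut()(state);
            });
        }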
        }
        Ok(())
    }
}

#[derive(Clone, Eq, PartialEq, Hash)]
enum ExeState {
    Waiting,
    Running,
    Finished,
    Failed,
}

struct Execution<'a> {
    exec_id: String,
    pub state_machine: StateMachine<'a, ExeState>,
}

struct ExecManager<'a> {
    all_jobs: Arc<RwLock<DashMap<String, Execution<'a>>>>,
    finished_jobs: Arc<RwLock<Vec<String>>>,
}

impl<'a> ExecManager<'a> {

    pub fn new() -> Self {
        ExecManager {
            all_jobs: Arc::new(RwLock::new(DashMap::new())),
            finished_jobs: Arc::new(RwLock::new(vec![]))
        }
    }

    fn add_job(&mut self, job_id: String) {
        let mut execution = Execution {
            exec_id: job_id.clone(),
            state_machine: StateMachine::new(ExeState::Waiting)
        };

        // add listener
        let callback_finished_jobs = self.finished_jobs.clone();
        let callback_job_id = job_id.clone();
        execution.state_machine.add_listener( Box::new(move |new_state| {
            println!("listener fired!, job_id {}", callback_job_id.clone());
            if new_state == ExeState::Finished || new_state == ExeState::Failed {
                let mut guard = callback_finished_jobs.write().unwrap();
                guard.deref_mut().push(callback_job_id.clone());

            }
            Ok(())
        }));

        let mut guard = self.all_jobs.write().unwrap();
        guard.deref_mut().insert(job_id, execution);
    }

    fn mock_exec(&mut self, job_id: String) {
        let mut guard = self.all_jobs.write().unwrap();
        let mut exec = guard.deref_mut().get_mut(&job_id).unwrap();

        exec.state_machine.set(ExeState::Finished);
    }

}


#[test]
fn test() {
    let mut manager = ExecManager::new();

    manager.add_job(String::from("job_id1"));
    manager.add_job(String::from("job_id2"));

    manager.mock_exec(String::from("job_id1"));
    manager.mock_exec(String::from("job_id2"));


}

Compile error from the second demo:

error[E0521]: borrowed data escapes outside of associated function
  --> generic/src/callback2.rs:34:34
   |
15 | impl<'a, T> StateMachine<'a,T> where T:Clone+Eq+Send+'a {
   |      -- lifetime `'a` defined here
...
29 |     pub fn set(&mut self, new_state: T) -> Result<()> {
   |                --------- `self` is a reference that is only valid in the associated function body
...
34 |             let spawn_listener = listener.clone();
   |                                  ^^^^^^^^^^^^^^^^
   |                                  |
   |                                  `self` escapes the associated function body here
   |                                  argument requires that `'a` must outlive `'static`
   |
   = note: requirement occurs because of the type `std::sync::Mutex<Box<dyn FnMut(T) -> Result<(), anyhow::Error> + Send + Sync>>`, which makes the generic argument `Box<dyn FnMut(T) -> Result<(), anyhow::Error> + Send + Sync>` invariant
   = note: the struct `std::sync::Mutex<T>` is invariant over the parameter `T`
   = help: see <https://doc.rust-lang.org/nomicon/subtyping.html> for more information about variance


Answers (1)

Kevin Reid

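Tasks spawned with tokio::spawn() cannot use borrowed data (here, the data with lifetime 'a, whatever it may be). tokio::spawn requires its future to be 'static, which is what the error means by "argument requires that 'a must outlive 'static". That bound exists because there is not currently (and likely never will be) any way to guarantee that the borrowed data reliably outlives the spawned task.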
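The same restriction exists on std::thread::spawn, whose closure must also be Send + 'static, so it can illustrate the point without a Tokio runtime. This is a minimal sketch; the helper name notify_in_background is hypothetical. The spawned closure owns an Arc clone and a String rather than references into the caller.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical helper: appends `event` to a shared log from a new thread.
/// thread::spawn, like tokio::spawn, only accepts a 'static closure, so the
/// thread gets its own Arc handle and an owned String instead of borrows.
fn notify_in_background(log: &Arc<Mutex<Vec<String>>>, event: &str) -> thread::JoinHandle<()> {
    let log = Arc::clone(log); // owned handle: satisfies 'static
    let event = event.to_string(); // owned copy of the borrowed &str
    // Moving the original `log: &Arc<_>` or `event: &str` into the closure
    // instead would be rejected at compile time, since neither is 'static.
    thread::spawn(move || {
        log.lock().unwrap().push(event);
    })
}

fn main() {
    let log = Arc::new(Mutex::new(Vec::new()));
    notify_in_background(&log, "Finished").join().unwrap();
    assert_eq!(*log.lock().unwrap(), vec!["Finished".to_string()]);
}
```

The Arc plays the role that the 'a borrow played in the question: every task keeps the shared data alive on its own, so no lifetime has to outlive the task.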

You have two choices:

  1. Fire the notifications without spawning. You can put the notification futures into a FuturesUnordered to run them all concurrently, but they will still all have to finish before try_set() does.

  2. Remove the lifetime parameter; stop allowing callbacks that borrow data. Put 'static on your dyn types where necessary. Change the users of the StateMachine so they do not try to use borrowed data but use Arc instead, if necessary.

    pub struct StateMachine<T> where T: Clone + Eq + 'static {
       state: RwLock<T>,
       listeners2: Vec<Arc<Mutex<ListenerCallback<T>>>>,
    }
    
    pub type ListenerCallback<T> = dyn FnMut(T) -> Result<()> + Send + Sync + 'static;
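Filling in option 2 a little further, here is a sketch of the whole machine with 'static callbacks. Assumptions: String errors stand in for anyhow::Error and std::thread::spawn stands in for tokio::spawn (both require 'static input), so it runs without external crates. The ExeState enum mirrors the question's demo, and the state() getter is added for illustration.

```rust
use std::sync::{Arc, Mutex, RwLock};
use std::thread;

// Assumption: String errors stand in for anyhow::Error (no external crates).
type Result<T> = std::result::Result<T, String>;

// 'static callbacks may own data (e.g. Arc clones) but cannot borrow it.
// Sync is not needed: Mutex<X> is Sync whenever X is Send.
type ListenerCallback<T> = dyn FnMut(T) -> Result<()> + Send + 'static;

struct StateMachine<T: Clone + Eq + Send + 'static> {
    state: RwLock<T>,
    listeners: Vec<Arc<Mutex<Box<ListenerCallback<T>>>>>,
}

impl<T: Clone + Eq + Send + 'static> StateMachine<T> {
    fn new(init_state: T) -> Self {
        StateMachine { state: RwLock::new(init_state), listeners: Vec::new() }
    }

    fn add_listener(&mut self, listener: Box<ListenerCallback<T>>) {
        self.listeners.push(Arc::new(Mutex::new(listener)));
    }

    fn state(&self) -> T {
        self.state.read().unwrap().clone()
    }

    /// Updates the state, then runs every listener on its own thread.
    fn try_set(&self, new_state: T) -> Result<()> {
        {
            let mut state = self.state.write().unwrap();
            if *state == new_state {
                return Ok(()); // unchanged: nothing to fire
            }
            *state = new_state.clone();
        } // write lock released before the listeners run

        let mut handles = Vec::new();
        for listener in &self.listeners {
            let listener = Arc::clone(listener); // owned handle, so 'static
            let state = new_state.clone();
            handles.push(thread::spawn(move || {
                let mut callback = listener.lock().unwrap();
                (*callback)(state)
            }));
        }

        // Join in order; return the first listener error (or panic) found.
        for handle in handles {
            handle.join().map_err(|_| "listener panicked".to_string())??;
        }
        Ok(())
    }
}

#[derive(Clone, Debug, PartialEq, Eq)]
enum ExeState {
    Waiting,
    Finished,
}

fn main() {
    let finished = Arc::new(Mutex::new(Vec::new()));
    let mut machine = StateMachine::new(ExeState::Waiting);

    let sink = Arc::clone(&finished); // the callback owns an Arc, not a borrow
    machine.add_listener(Box::new(move |s: ExeState| -> Result<()> {
        sink.lock().unwrap().push(s);
        Ok(())
    }));

    machine.try_set(ExeState::Finished).unwrap();
    machine.try_set(ExeState::Finished).unwrap(); // unchanged, so not fired again
    assert_eq!(*finished.lock().unwrap(), vec![ExeState::Finished]);
    assert_eq!(machine.state(), ExeState::Finished);
}
```

The callers change in the way the answer describes: the listener captures an Arc clone (as the question's demos already do with finished_jobs), so the StateMachine and ExecManager no longer need a lifetime parameter.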
    

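For option 1, the key property is that every callback finishes before the setter returns, which is what lets callbacks keep borrowing data with lifetime 'a. The sketch below swaps in std::thread::scope for FuturesUnordered, because the question's callbacks are synchronous FnMut closures rather than futures. It is a minimal sketch under assumptions: String errors replace anyhow::Error, and ExeState mirrors the question's demo.

```rust
use std::sync::Mutex;
use std::thread;

// Assumption: String errors stand in for anyhow::Error (no external crates).
type Result<T> = std::result::Result<T, String>;

// Callbacks may borrow data that lives for 'a; no 'static bound.
type Callback<'a, T> = dyn FnMut(T) -> Result<()> + Send + 'a;

#[derive(Clone, Debug, PartialEq, Eq)]
enum ExeState {
    Waiting,
    Finished,
}

struct StateMachine<'a, T: Clone + Eq + Send> {
    state: T,
    listeners: Vec<Box<Callback<'a, T>>>,
}

impl<'a, T: Clone + Eq + Send> StateMachine<'a, T> {
    fn new(init_state: T) -> Self {
        StateMachine { state: init_state, listeners: Vec::new() }
    }

    fn add_listener(&mut self, listener: Box<Callback<'a, T>>) {
        self.listeners.push(listener);
    }

    /// Runs all listeners concurrently on scoped threads. thread::scope joins
    /// every thread before it returns, so borrowed data stays valid.
    fn set(&mut self, new_state: T) -> Result<()> {
        if self.state == new_state {
            return Ok(()); // unchanged: nothing to fire
        }
        self.state = new_state.clone();

        thread::scope(|s| -> Result<()> {
            let mut handles = Vec::new();
            for listener in self.listeners.iter_mut() {
                let state = new_state.clone();
                handles.push(s.spawn(move || listener(state)));
            }
            // Join in order and return the first error; scope() joins the rest.
            for handle in handles {
                handle.join().map_err(|_| "listener panicked".to_string())??;
            }
            Ok(())
        })
    }
}

fn main() {
    let finished = Mutex::new(Vec::new()); // plain stack value, no Arc
    let mut machine = StateMachine::new(ExeState::Waiting);
    machine.add_listener(Box::new(|s: ExeState| -> Result<()> {
        finished.lock().unwrap().push(s); // borrows `finished`
        Ok(())
    }));
    machine.set(ExeState::Finished).unwrap();
    assert_eq!(*finished.lock().unwrap(), vec![ExeState::Finished]);
}
```

The trade-off matches the answer: listeners still run concurrently, but set() blocks until all of them are done, so nothing outlives the borrow.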