Reputation: 1584
I am looking for a primitive that makes two tasks wait until an event happens and then continue, but if the event has already happened, they should not wait at all. I imagine a watch channel may work, but that does not seem very idiomatic.
A plain tokio Notify may not work, because the event may happen before a task starts waiting for the notification. So maybe adding a boolean flag wrapped in a Mutex is necessary (one way to make that combination race-free is sketched after the snippet below). But I want a standard, idiomatic solution, as I imagine the problem is not new. Basically, I want code similar to the following to work, with no race conditions.
use std::sync::Arc;
use tokio::sync::Notify;

#[tokio::main]
async fn main() {
    let notify = Arc::new(Notify::new());
    let notify2 = notify.clone();
    let notify3 = notify.clone();
    println!("sending notification");
    // The event happens before either task starts waiting.
    notify.notify_waiters();
    let handle1 = tokio::spawn(async move {
        notify2.notified().await;
        println!("received notification 1");
    });
    let handle2 = tokio::spawn(async move {
        notify3.notified().await;
        println!("received notification 2");
    });
    // Wait for both tasks to receive the notification.
    handle1.await.unwrap();
    handle2.await.unwrap();
}
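For reference, here is one way the "flag plus Notify" idea can be made race-free: a minimal sketch, assuming a tokio version that provides Notified::enable, and using an AtomicBool in place of the Mutex-wrapped boolean.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Notify;

struct Event {
    happened: AtomicBool,
    notify: Notify,
}

impl Event {
    fn new() -> Self {
        Self {
            happened: AtomicBool::new(false),
            notify: Notify::new(),
        }
    }

    fn set(&self) {
        // Set the flag first, then wake every registered waiter.
        self.happened.store(true, Ordering::Release);
        self.notify.notify_waiters();
    }

    async fn wait(&self) {
        let notified = self.notify.notified();
        tokio::pin!(notified);
        // Register for notify_waiters *before* checking the flag, so that a
        // set() happening between the check and the await cannot be missed.
        notified.as_mut().enable();
        if self.happened.load(Ordering::Acquire) {
            return;
        }
        notified.await;
    }
}

#[tokio::main]
async fn main() {
    let event = Arc::new(Event::new());
    // The event happens before the tasks start waiting.
    event.set();
    let e1 = Arc::clone(&event);
    let e2 = Arc::clone(&event);
    let h1 = tokio::spawn(async move {
        e1.wait().await;
        println!("received notification 1");
    });
    let h2 = tokio::spawn(async move {
        e2.wait().await;
        println!("received notification 2");
    });
    h1.await.unwrap();
    h2.await.unwrap();
}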
Upvotes: 0
Views: 120
Reputation: 532
Here's a potential solution leveraging a combination of a semaphore and an atomic counter:
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let event = Arc::new(EventNotifier::default());
    let mut tasks = Vec::with_capacity(2);
    for _ in 0..2 {
        let task = tokio::spawn({
            let event = event.clone();
            async move {
                // wait for the event
                event.notified().await;
                // do something afterwards
            }
        });
        tasks.push(task);
    }
    event.notify();
    for t in tasks {
        let _ = t.await;
    }
}

struct EventNotifier {
    // Atomic state of the event: the lowest bit is a boolean "happened" flag,
    // and the remaining bits count the tasks currently waiting for the event.
    atomic: AtomicUsize,
    // Just an asynchronous signaling mechanism.
    permits: Semaphore,
}

impl Default for EventNotifier {
    fn default() -> Self {
        Self {
            atomic: AtomicUsize::new(0),
            permits: Semaphore::new(0),
        }
    }
}

impl EventNotifier {
    pub fn notify(&self) {
        let mut value = self.atomic.load(Ordering::Acquire);
        let waiters = loop {
            let waiters = value >> 1; // number of currently waiting tasks
            // Try to mark the event as complete (flag bit set, waiter count cleared).
            if let Err(v) =
                self.atomic
                    .compare_exchange(value, 1, Ordering::Release, Ordering::Acquire)
            {
                // If another task registered itself as waiting in the meantime,
                // the CAS fails; retry with the fresh value.
                value = v;
            } else {
                break waiters;
            }
        };
        // Notify all the waiting tasks about event completion.
        self.permits.add_permits(waiters);
    }

    pub async fn notified(&self) {
        let mut value = self.atomic.load(Ordering::Acquire);
        loop {
            // If the event has already completed, just return.
            if value & 1 == 1 {
                return;
            }
            // Otherwise try to register this task as a waiter.
            let waiters = (value >> 1) + 1;
            let flag = value & 1;
            let new = flag | (waiters << 1);
            // If the event completed or another task registered itself as waiting
            // concurrently, the CAS fails; retry the check above.
            if let Err(v) =
                self.atomic
                    .compare_exchange(value, new, Ordering::Release, Ordering::Acquire)
            {
                value = v
            } else {
                break;
            }
        }
        // Wait asynchronously for the event-completion notification.
        let _ = self.permits.acquire().await;
    }
}
The implementation should handle all possible interleavings without race conditions; a quick check of the "notify before anyone waits" path is sketched below.
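For completeness, a small test sketch (not part of the original answer, and assuming tokio's test macro is enabled) that reuses the EventNotifier type defined above and exercises the early-notify path:
use std::sync::Arc;

#[tokio::test]
async fn notify_before_anyone_waits() {
    let event = Arc::new(EventNotifier::default());
    // The event happens before any task registers as a waiter.
    event.notify();
    let event2 = Arc::clone(&event);
    let task = tokio::spawn(async move {
        // Must return immediately because the flag bit is already set,
        // instead of parking on the semaphore forever.
        event2.notified().await;
    });
    task.await.unwrap();
}
Note that if notified() ever did park here, the test would hang rather than fail cleanly, so in practice you might wrap the await in a timeout; this is only a sketch.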
Upvotes: 1
Reputation: 1584
Here is a sketch of a solution with watch channels. It seems to work correctly, so I will proceed with it for now, but I would be grateful for a more idiomatic solution using some analogue of a condvar.
use rand::{thread_rng, Rng};
use tokio::sync::watch;
use tokio::time::{self, Duration};

#[tokio::main(flavor = "multi_thread", worker_threads = 4)]
async fn main() {
    const TEST_COUNT: usize = 10_000;
    // Run the same "trigger an event once, let two tasks wait for it" flow 10k times.
    for i in 0..TEST_COUNT {
        let (tx, rx) = watch::channel(false);
        let mut rx1 = rx.clone();
        let mut rx2 = rx.clone();
        // Task A
        let handle1 = tokio::spawn(async move {
            // Random short sleep to scramble scheduling.
            let delay = thread_rng().gen_range(0..2);
            time::sleep(Duration::from_millis(delay)).await;
            // Check whether the event has already happened.
            // The explicit `borrow` check is only an optimization that avoids an
            // unnecessary async suspension when the receiver is already up to date.
            if !*rx1.borrow() {
                let delay = thread_rng().gen_range(0..2);
                time::sleep(Duration::from_millis(delay)).await;
                // If not, await the first change to `true`.
                // Under the hood, each watch::Receiver has an internal sequence number
                // indicating the version of the channel's state it has seen. Every time
                // the sender calls tx.send(...), this version is incremented. When you
                // call changed().await, if the receiver's version is out of date
                // (i.e., less than the channel's current version), changed() returns
                // immediately. This is how the watch channel prevents "missing updates"
                // even if the change happens between your "check" and your "await".
                rx1.changed().await.expect("watch channel closed");
            }
        });
        // Task B
        let handle2 = tokio::spawn(async move {
            let delay = thread_rng().gen_range(0..2);
            time::sleep(Duration::from_millis(delay)).await;
            if !*rx2.borrow() {
                let delay = thread_rng().gen_range(0..2);
                time::sleep(Duration::from_millis(delay)).await;
                rx2.changed().await.expect("watch channel closed");
            }
        });
        // Random short sleep before triggering the event, so that in some iterations
        // the tasks are already waiting and in others they are not waiting yet.
        let delay = thread_rng().gen_range(0..4);
        time::sleep(Duration::from_millis(delay)).await;
        // "The event has happened."
        tx.send(true).expect("watch channel closed");
        // Wait for both tasks to confirm receipt of the `true` state.
        handle1.await.unwrap();
        handle2.await.unwrap();
        // Print progress occasionally.
        if (i + 1) % 1000 == 0 {
            println!("Finished iteration {}", i + 1);
        }
    }
    println!("All {} iterations completed successfully.", TEST_COUNT);
}
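As a side note, newer tokio versions also provide watch::Receiver::wait_for, which collapses the "check borrow(), then await changed()" dance into a single call. A minimal sketch, assuming that method is available in the tokio version in use:
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(false);
    // The event happens before the task starts waiting.
    tx.send(true).expect("watch channel closed");
    let handle = tokio::spawn(async move {
        // wait_for resolves immediately if the predicate already holds,
        // otherwise it suspends until a sent value satisfies it.
        rx.wait_for(|ready| *ready).await.expect("watch channel closed");
        println!("received notification");
    });
    handle.await.unwrap();
}
Because wait_for checks the current value before suspending, it gives exactly the "don't wait if the event already happened" behaviour asked for.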
Upvotes: 0
Reputation: 1
I don't have much experience with async, so take my words with a grain of salt.
I think for what you're describing the only way is, as you described, to use an Arc<Mutex<T>>
wrapper around some state that you keep track of across your tasks.
The code would look something like this:
use std::sync::Arc;
use tokio::sync::Mutex;

async fn notified_functions(name: &str, shared_state: Arc<Mutex<bool>>) {
    // Poll the shared flag until it becomes true (note: this busy-waits).
    loop {
        if *shared_state.lock().await {
            break;
        }
    }
    println!("Notified[{name}].");
}

#[tokio::main]
async fn main() {
    let shared_state = Arc::new(Mutex::new(false));
    // The event happens before the tasks are spawned.
    *shared_state.lock().await = true;
    let mut handles = Vec::new();
    handles.push(tokio::spawn(notified_functions("1", shared_state.clone())));
    handles.push(tokio::spawn(notified_functions("2", shared_state.clone())));
    // Wait for the tasks to finish.
    for h in handles {
        h.await.unwrap();
    }
}
You can look into more options in case you need to send messages and want better performance; crossbeam and flume would be a good start. But I have zero experience with those, and they might still have the race-condition problem you were trying to fix.
I hope this helps. I know it's not the most satisfying answer, but I think this is the idiomatic solution you were looking for and actually guessed at in your original post.
Upvotes: -3