Gatonito

try_lock on futures::lock::Mutex outside of async?

I'm trying to implement AsyncRead for a struct that holds a futures::lock::Mutex:

pub struct SmolSocket<'a> {
    stack: Arc<futures::lock::Mutex<SmolStackWithDevice<'a>>>,
}

impl<'a> AsyncRead for SmolSocket<'a>  {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>
    ) -> Poll<std::io::Result<()>> {
        // Pseudocode: block_on would block the executor thread until the lock is free.
        block_on(self.stack.lock()).read(...)
    }
}

The problem is that poll_read is not async, so I cannot .await the lock. I also don't want to use block_on, because it would block the thread. I could call try_lock instead and, if the lock is not available, register a Waker so that SmolSocket gets woken and polled again later.

Since I can't do that either (poll_read is not async), is there a version of block_on that behaves like try_lock for futures::lock::Mutex outside of async code?


Answers (1)

Mihail Malostanidis

You probably want to poll the MutexLockFuture instead. This can be done, for example, with the core::task::ready! macro, which desugars as follows:

let num = match fut.poll(cx) {
    Poll::Ready(t) => t,
    Poll::Pending => return Poll::Pending,
};

To poll a future, you also need to pin it (guarantee that it won't be moved). This can be done on the stack with tokio::pin!, with Pin::new if the type is already Unpin (MutexLockFuture is), or by moving it onto the heap with Box::pin.
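As a rough illustration, here is a std-only sketch of the three pinning options, each followed by a ready! call that propagates Pending exactly as in the desugaring above. It assumes Rust 1.68+ for std::pin::pin!, used here in place of tokio::pin! so the sketch needs no dependencies, and a hand-rolled no-op waker. The helpers noop_waker, add_one and demo are hypothetical names, not library APIs.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::task::{ready, Context, Poll, RawWaker, RawWakerVTable, Waker};

// Hypothetical helper: a waker whose clone/wake/drop do nothing.
// Good enough for polling futures by hand in an example.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    // SAFETY: every vtable function ignores the (null) data pointer.
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}

// Hypothetical helper: poll an inner future once; ready! returns
// Poll::Pending early if the inner future is not done yet.
fn add_one<F: Future<Output = u32>>(fut: Pin<&mut F>, cx: &mut Context<'_>) -> Poll<u32> {
    let n = ready!(fut.poll(cx));
    Poll::Ready(n + 1)
}

fn demo() -> [Poll<u32>; 3] {
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);

    // 1. Pin::new: only allowed because std::future::Ready<u32> is Unpin.
    let mut unpin_fut = std::future::ready(1u32);
    let a = add_one(Pin::new(&mut unpin_fut), &mut cx);

    // 2. Box::pin: works for any future, including !Unpin async blocks,
    //    at the cost of a heap allocation.
    let mut boxed = Box::pin(async { 10u32 });
    let b = add_one(boxed.as_mut(), &mut cx);

    // 3. Stack pinning with a macro; the async block can no longer be moved.
    let stacked = pin!(async { 100u32 });
    let c = add_one(stacked, &mut cx);

    [a, b, c]
}

fn main() {
    println!("{:?}", demo()); // [Ready(2), Ready(11), Ready(101)]
}
```

Pin::new is the cheapest option but only type-checks for Unpin futures such as MutexLockFuture; the other two work for any future.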

Below is a runnable example.

⚠️ KEEP READING TO SEE WHY YOU DON'T WANT TO DO THIS!

// ready! has been stable since Rust 1.64, so no #![feature(ready_macro)] attribute is needed.
use core::{
    future::Future,
    pin::Pin,
    task::{ready, Context, Poll},
};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt};
pub struct SmolStackWithDevice<'a> {
    counter: usize,
    data: &'a [u8],
}
impl<'a> AsyncRead for SmolStackWithDevice<'a> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if self.counter % 2 == 0 {
            self.counter += 1;
            cx.waker().wake_by_ref();
            println!("read nothing");
            return Poll::Pending;
        }
        buf.put_slice(&[self.data[self.counter / 2]]);
        self.counter += 1;
        println!("read something");
        Poll::Ready(Ok(()))
    }
}
pub struct SmolSocket<'a> {
    stack: Arc<futures::lock::Mutex<SmolStackWithDevice<'a>>>,
}
impl<'a> AsyncRead for SmolSocket<'a> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut lock_fut = self.stack.lock();
        let pinned_lock_fut = Pin::new(&mut lock_fut);
        let mut guard = ready!(pinned_lock_fut.poll(cx));
        println!("acquired lock");
        let pinned_inner = Pin::new(&mut *guard);
        pinned_inner.poll_read(cx, buf)
    }
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
    let data = b"HORSE";
    let mut buf = [0; 5];
    let mut s = SmolSocket {
        stack: Arc::new(
            SmolStackWithDevice {
                counter: 0,
                data: &data[..],
            }
            .into(),
        ),
    };
    s.read_exact(&mut buf).await.unwrap();
    println!("{}", String::from_utf8_lossy(&buf));
}

Look at it go! (in Rust Playground)

⚠️ KEEP READING TO SEE WHY YOU DON'T WANT TO DO THIS!

So, what is the problem? As you can see from the output, whenever we acquire the lock but the underlying source is not ready to read, or only gives us a short read, we drop the lock, and on the next poll we have to acquire it again.

This is a good point to remember that async flavors of Mutex are only recommended over std::sync::Mutex or parking_lot when the guard from a successful lock is expected to be held across an .await, or explicitly stored in a Future's data structure.

We are not doing that here. We only ever exercise the fast path, equivalent to Mutex::try_lock: whenever the lock is not immediately available, we drop the MutexLockFuture, and with it the waker it registered, instead of keeping it around to be woken and polled again.

However, storing the lock guard in the data structure would make it easy to deadlock by accident. So a good design might be an awkward-to-store (borrowing) AsyncRead adapter that wraps the lock:

// Same imports and SmolStackWithDevice as in the previous example.
pub struct SmolSocket<'a> {
    stack: Arc<futures::lock::Mutex<SmolStackWithDevice<'a>>>,
}
impl<'a> SmolSocket<'a> {
    fn read(&'a self) -> Reader<'a> {
        Reader::Locking(self.stack.lock())
    }
}
pub enum Reader<'a> {
    Locking(futures::lock::MutexLockFuture<'a, SmolStackWithDevice<'a>>),
    Locked(futures::lock::MutexGuard<'a, SmolStackWithDevice<'a>>),
}
impl<'a> AsyncRead for Reader<'a> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this {
            Reader::Locking(f) => {
                *this = Reader::Locked(ready!(Pin::new(f).poll(cx)));
                println!("acquired lock");
                Pin::new(this).poll_read(cx, buf)
            }
            Reader::Locked(l) => Pin::new(&mut **l).poll_read(cx, buf),
        }
    }
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
    let data = b"HORSE";
    let mut buf = [0; 5];
    let s = SmolSocket {
        stack: Arc::new(
            SmolStackWithDevice {
                counter: 0,
                data: &data[..],
            }
            .into(),
        ),
    };
    s.read().read_exact(&mut buf).await.unwrap();
    println!("{}", String::from_utf8_lossy(&buf));
}

Look at it go! (executable Playground link)

This works because both MutexLockFuture and our SmolStackWithDevice are Unpin (not self-referential), so we don't have to guarantee that we aren't moving them.

In the general case, for example if your SmolStackWithDevice is not Unpin, you'd have to project the Pin manually, like this:

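// SAFETY (sketch): assumes the Reader is never moved out of its Pin, and that
// no other code ever moves the SmolStackWithDevice out of the Mutex.
unsafe {
    let this = self.get_unchecked_mut();
    match this {
        Reader::Locking(f) => {
            // Assigning to *this drops the completed lock future in place.
            *this = Reader::Locked(ready!(Pin::new_unchecked(f).poll(cx)));
            println!("acquired lock");
            Pin::new_unchecked(this).poll_read(cx, buf)
        }
        // The guarded value lives inside the Arc'd Mutex, not inside the Reader.
        Reader::Locked(l) => Pin::new_unchecked(&mut **l).poll_read(cx, buf),
    }
}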

I'm not sure how to encapsulate the unsafety; pin_project isn't enough here, because we also need to dereference the guard.

But this version acquires the lock only once and releases it when the Reader is dropped, so: great success.

You can also see that it doesn't deadlock if you do:

    let mut r1 = s.read();
    let mut r2 = s.read();
    r1.read_exact(&mut buf[..3]).await.unwrap();
    drop(r1);
    r2.read_exact(&mut buf[3..]).await.unwrap();
    println!("{}", String::from_utf8_lossy(&buf));

This is only possible because we deferred locking until polling.
