F.Chen

Reputation: 83

How to await for the first k futures?

Say we have n servers, and we only need k < n responses.

I understand futures::join_all can be used to wait for all n futures, but I want my program to stop waiting after k responses.

Is there something similar to join_all that I can use to wait for the first k responses?

Upvotes: 8

Views: 2865

Answers (2)

Freyja

Reputation: 40784

You can use streams (async iterators) for this. FuturesUnordered is an unordered collection of futures that can be used as a stream, yielding each future's result in the order the futures complete. Combine it with .take(k) to take only the first k items of the stream, and then .collect::<Vec<_>>() to wait for the stream to finish and collect the results in a Vec:

use futures::prelude::*;
use futures::stream::FuturesUnordered;

let futures = vec![
    // assume `f(n, t)` = sleep for `t` millis, then return `n`
    f(1, 1000),
    f(2, 10),
    f(3, 105),
    f(4, 40),
    f(5, 70),
    f(6, 270),
];

// create unordered collection of futures
let futures = futures.into_iter().collect::<FuturesUnordered<_>>();

// use collection as a stream, await only first 4 futures to complete
let first_4 = futures.take(4).collect::<Vec<_>>().await;

// note: any futures that have not completed yet are dropped (cancelled)
// automatically once the stream is dropped, i.e. when the `.await` above returns

// check with expected result, based on the order of completion
assert_eq!(first_4, vec![2, 4, 5, 3]);
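For comparison, the same "take the first k results as they arrive" shape can be sketched without async, using std threads and an mpsc channel. This swaps in a different technique (blocking threads instead of futures) and is only a minimal sketch: first_k_blocking is a hypothetical helper, and each (n, t) pair is an assumed stand-in for a server that responds with n after t milliseconds, mirroring the f(n, t) assumption above.

```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

/// Hypothetical helper for illustration: each `(n, t)` simulates a server that
/// sleeps `t` ms and then responds with `n`. Returns the first `k` responses
/// in the order they arrive (fewer if there are fewer than `k` servers).
fn first_k_blocking(servers: Vec<(u32, u64)>, k: usize) -> Vec<u32> {
    let (tx, rx) = mpsc::channel();
    for (n, t) in servers {
        let tx = tx.clone();
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(t));
            // `send` fails once `rx` is dropped; late responses are ignored.
            let _ = tx.send(n);
        });
    }
    // Drop our own sender so `rx.iter()` ends once every thread has sent,
    // instead of blocking forever when there are fewer than `k` servers.
    drop(tx);
    // Same shape as `futures.take(k).collect()` on a FuturesUnordered.
    rx.iter().take(k).collect()
}
```

Unlike dropping a FuturesUnordered, this cancels nothing: the slower threads keep running, and their send calls simply fail once the receiver has been dropped. With the same inputs as above it should normally return [2, 4, 5, 3], although thread scheduling can in principle reorder responses that arrive very close together.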


Edit: If you also want the index of each completed future, you can use this:

// create unordered collection of futures with indices
let futures = futures
    .into_iter()
    .enumerate()
    // `FutureExt::map` (from futures::prelude) tags each result with its index
    .map(|(i, fut)| fut.map(move |res| (i, res)))
    .collect::<FuturesUnordered<_>>();


Upvotes: 6

kmdreko

Reputation: 59817

I don't believe there is anything built for this purpose. Perhaps you can do this with Streams or Channels, but the join_all implementation isn't too complicated. I've modified it so that it only waits for n results:

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::future::MaybeDone; // 0.3.15

fn iter_pin_mut<T>(slice: Pin<&mut [T]>) -> impl Iterator<Item = Pin<&mut T>> {
    // Safety: `std` _could_ make this unsound if it were to decide Pin's
    // invariants aren't required to transmit through slices. Otherwise this has
    // the same safety as a normal field pin projection.
    //
    // Copied from `futures` implementation of `join_all`.
    unsafe { slice.get_unchecked_mut() }.iter_mut().map(|t| unsafe { Pin::new_unchecked(t) })
}

#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct JoinSome<F>
where
    F: Future,
{
    elems: Pin<Box<[MaybeDone<F>]>>,
    count: usize,
}

/// Will wait for at least `n` futures to complete. More may be returned if
/// multiple resolve around the same time.
///
/// # Panics
///
/// Will panic if iterator doesn't contain at least `n` futures.
pub fn join_some<I>(i: I, n: usize) -> JoinSome<I::Item>
where
    I: IntoIterator,
    I::Item: Future,
{
    let elems: Box<[_]> = i.into_iter().map(MaybeDone::Future).collect();
    assert!(elems.len() >= n);
    JoinSome { elems: elems.into(), count: n }
}

impl<F> Future for JoinSome<F>
where
    F: Future,
{
    type Output = Vec<F::Output>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut num_done = 0;

        for elem in iter_pin_mut(self.elems.as_mut()) {
            if !elem.poll(cx).is_pending() {
                num_done += 1;
            }
        }

        if num_done >= self.count {
            let mut elems = std::mem::replace(&mut self.elems, Box::pin([]));
            let result = iter_pin_mut(elems.as_mut()).filter_map(|e| e.take_output()).collect();
            Poll::Ready(result)
        } else {
            Poll::Pending
        }
    }
}

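I've added documentation in the source that explains its behavior. Note that the results come back in input order rather than completion order, and that any futures still pending are dropped once n have finished. You can see that it works with this simple test program: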

#[tokio::main]
async fn main() {
    use std::time::{Instant, Duration};
    use tokio::time::sleep;

    let futures = vec![
        sleep(Duration::from_secs(1)),
        sleep(Duration::from_secs(2)),
        sleep(Duration::from_secs(3)),
        sleep(Duration::from_secs(4)),
        sleep(Duration::from_secs(5)),
    ];
    
    let now = Instant::now();
    let some = join_some(futures, 3).await;
    let elapsed = now.elapsed();
    
    println!("{} results in {:.2?}", some.len(), elapsed);
}

Output: 3 results in 3.00s

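The join_some above may return more than n results, and it returns them in input order. If you instead want exactly the first k results in the order they finished, each tagged with its index (like the enumerate variant in the other answer), the same poll-every-future approach can be sketched with the standard library alone. This is a minimal sketch under stated assumptions, not a drop-in replacement: first_k, FirstK, block_on and the toy Countdown future are hypothetical names made up for illustration, each future is boxed rather than wrapped in MaybeDone, and "first" means the order in which the futures are observed to finish while polling, with ties inside a single poll broken by input index.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

// Hypothetical names, for illustration only.
/// Resolves with exactly the first `k` outputs, each tagged with its input index.
pub struct FirstK<F: Future> {
    pending: Vec<Option<Pin<Box<F>>>>, // `None` once a future has finished
    done: Vec<(usize, F::Output)>,
    k: usize,
}

// Safe: we never pin-project into the fields; each future is pinned in its own Box.
impl<F: Future> Unpin for FirstK<F> {}

/// Panics if the iterator yields fewer than `k` futures.
pub fn first_k<I>(iter: I, k: usize) -> FirstK<I::Item>
where
    I: IntoIterator,
    I::Item: Future,
{
    let pending: Vec<_> = iter.into_iter().map(|f| Some(Box::pin(f))).collect();
    assert!(pending.len() >= k, "need at least k futures");
    FirstK { pending, done: Vec::new(), k }
}

impl<F: Future> Future for FirstK<F> {
    type Output = Vec<(usize, F::Output)>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut(); // allowed because FirstK is Unpin
        for (i, slot) in this.pending.iter_mut().enumerate() {
            if this.done.len() == this.k {
                break; // already have k results: don't poll the rest
            }
            if let Some(fut) = slot.as_mut() {
                if let Poll::Ready(out) = fut.as_mut().poll(cx) {
                    this.done.push((i, out));
                    *slot = None; // a finished future must not be polled again
                }
            }
        }
        if this.done.len() == this.k {
            // Unfinished futures are dropped (cancelled) along with `self`.
            Poll::Ready(std::mem::take(&mut this.done))
        } else {
            Poll::Pending // every unfinished future has registered `cx`'s waker
        }
    }
}

/// Toy executor: polls on the current thread, parking it until woken.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

pub fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(), // may wake spuriously; we just re-poll
        }
    }
}

/// Toy future: returns `Pending` `remaining` times (waking itself), then `value`.
pub struct Countdown {
    pub remaining: u32,
    pub value: u32,
}

impl Future for Countdown {
    type Output = u32;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<u32> {
        if self.remaining == 0 {
            Poll::Ready(self.value)
        } else {
            self.remaining -= 1;
            cx.waker().wake_by_ref(); // ask the executor to poll again
            Poll::Pending
        }
    }
}
```

For example, Countdown futures with remaining counts 5, 0, 3, 1, 4, 2 and values 10, 20, ..., 60 should give block_on(first_k(futs, 3)) == [(1, 20), (3, 40), (5, 60)]. Like join_some, this polls every unfinished future on every wakeup, which is fine for a handful of servers; FuturesUnordered avoids that cost by polling only the futures that were actually woken.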

Upvotes: 2
