S0i

Reputation: 25

An efficient way to make recursive API calls in Rust using tokio

Problem

I recently took a coding exam and faced a problem like the one below.

I solved it in a different language, but now I'm trying to solve it in Rust to improve my Rust coding skills.


You are given a number n and an API endpoint.

Implement the following function f(n) and output the value of f(n).

f(0) = 1
f(2) = 2
f(n) = f(n-1) + f(n-2) + f(n-3)   (when n % 2 == 0)
f(n) = CallAPI(n)                 (when n % 2 != 0)
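To pin down the semantics before worrying about async, here is a minimal synchronous sketch of the definition. It assumes CallAPI can be modelled as a plain function; `f_naive` and `call_api` are hypothetical names, not part of the original problem. It also counts API calls: because every odd branch ends immediately in CallAPI, plain recursion is linear rather than exponential here. For even n ≥ 4 it makes n - 2 API calls, while only the n / 2 distinct odd values 1, 3, ..., n - 1 are actually needed.

```rust
/// A direct transcription of the definition, with no memoization.
/// `call_api` is a hypothetical synchronous stand-in for CallAPI, and
/// `calls` counts how many times it is invoked.
fn f_naive<F: Fn(usize) -> usize>(n: usize, call_api: &F, calls: &mut usize) -> usize {
    match n {
        0 => 1,
        2 => 2,
        // Odd n: exactly one API call, no further recursion.
        _ if n % 2 != 0 => {
            *calls += 1;
            call_api(n)
        }
        // Even n >= 4: recurse on the three previous values.
        _ => {
            f_naive(n - 1, call_api, calls)
                + f_naive(n - 2, call_api, calls)
                + f_naive(n - 3, call_api, calls)
        }
    }
}
```

For example, with a mock that always returns 100 (as in the code below), `f_naive(60, ...)` returns 5802 after 58 calls, whereas a memoized version needs only 30.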

My Solution

I thought it could be solved with memoization, so I wrote the following code.

use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};

use async_recursion::async_recursion;

type Memo = Arc<Mutex<HashMap<usize, usize>>>;

#[tokio::main]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // In the real program, the number n would be passed as an input argument.
    let n = "60";
    let mut temp = HashMap::new();
    temp.insert(0, 1);
    temp.insert(2, 2);
    let memo: Memo = Arc::new(Mutex::new(temp));
    let ans = f(n.parse::<usize>().expect("error"), memo).await;
    println!("{}", ans);
    Ok(())
}

#[async_recursion]
async fn f(n: usize, memo: Memo) -> usize {
    if let Some(&ans) = memo.lock().unwrap().get(&n) {
        return ans;
    }
    if n % 2 == 0 {
        let memo1 = memo.clone();
        let memo2 = memo.clone();
        let memo3 = memo.clone();
        let (one, two, three) = tokio::join!(f(n - 1, memo1), f(n - 2, memo2), f(n - 3, memo3));
        let ans = one + two + three;
        let mut mut_memo = memo.lock().unwrap();
        mut_memo.insert(n, ans);
        ans
    } else {
        let n_str = n.to_string();
        let ans = ask_server(&n_str).await;
        let mut mut_memo = memo.lock().unwrap();
        mut_memo.insert(n, ans);
        ans
    }
}

async fn ask_server(_n: &str) -> usize {
    // In the real program I send an HTTP GET request (so this function needs to be async)
    // and return the number contained in the response, but here I just return 100 for simplicity.
    100
}

It seems to work, but its performance is definitely not good. (Sorry for my terrible code...) I assume this is because it locks the "memo" variable every time, but despite my best efforts I can't figure out how to improve it.

I would like to ask those who know Rust well how to implement this efficiently.
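One way to avoid the shared `Arc<Mutex<HashMap>>` altogether (a sketch of the structure, not a measured improvement) is to split the work into two phases: first fetch CallAPI(k) exactly once for every odd k ≤ n, then fill a plain table bottom-up with no locking. In an async program the first phase is where the concurrency would go, for example with `futures::future::join_all` or a bounded `FuturesUnordered`; the sketch keeps it synchronous and uses a hypothetical `call_api` closure. `fetch_api_values` and `f_bottom_up` are illustrative names.

```rust
/// Phase 1 (hypothetical helper): collect CallAPI(k) for every odd k <= n
/// into a Vec indexed by k. Even slots stay 0 and are never read.
/// In async code these requests would be issued concurrently instead.
fn fetch_api_values(n: usize, call_api: impl Fn(usize) -> usize) -> Vec<usize> {
    let mut values = vec![0usize; n + 1];
    for k in (1..=n).step_by(2) {
        values[k] = call_api(k);
    }
    values
}

/// Phase 2: evaluate f(0..=n) bottom-up from the prefetched values.
/// No shared memo and no Mutex: each entry depends only on earlier entries.
fn f_bottom_up(n: usize, api_values: &[usize]) -> usize {
    let mut table = vec![0usize; n + 1];
    for k in 0..=n {
        table[k] = match k {
            0 => 1,
            2 => 2,
            _ if k % 2 != 0 => api_values[k],
            // Even k >= 4, so k - 3 >= 1 is always in range.
            _ => table[k - 1] + table[k - 2] + table[k - 3],
        };
    }
    table[n]
}
```

Usage: `f_bottom_up(60, &fetch_api_values(60, |_| 100))` evaluates the same recurrence as the code above without any locking.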

Upvotes: 1

Views: 204

Answers (1)

Locke

Reputation: 8944

My first step was to look at the problem and figure out whether we can rewrite the recurrence into a more convenient form.

f(0) = 1
f(1) = CallAPI(1)
f(2) = 2
f(3) = CallAPI(3)
f(4) = CallAPI(3) + 2 + CallAPI(1)
f(5) = CallAPI(5)
f(6) = CallAPI(5) + 2 * CallAPI(3) + 2 + CallAPI(1)

From here, the first pattern I notice is that once n is over 3 there are only two cases. If n is odd, we can just use CallAPI(n). If n is even, the total becomes (CallAPI(n - 1) + CallAPI(n - 3)) + (CallAPI(n - 3) + CallAPI(n - 5)) + ... + (CallAPI(3) + CallAPI(1)) + 2. In other words, each CallAPI value appears in at most two adjacent terms: CallAPI(1) and CallAPI(n - 1) are counted once, and every odd k from 3 to n - 3 is counted twice. No API result depends on any other result, so we don't need a memo at all.

This is great because we can convert the recursion into an iterative approach where all of the API calls are independent and can be grouped together, which lets us wait on far more futures concurrently. While we could attempt to perform every CallAPI concurrently, we don't want to use up all of the memory on the system or overload the API server, so we impose a limit on how many requests are in flight at any given time.
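Writing C(k) for CallAPI(k), the pattern above amounts to the following closed form for even n ≥ 4, and a short induction on the recurrence confirms it (using f(n-1) = C(n-1) and f(n-3) = C(n-3), since both are odd):

```latex
% Closed form for even n >= 4, with C(k) = CallAPI(k)
$$ f(n) = 2 + C(1) + C(n-1) + 2\sum_{j=1}^{n/2-2} C(2j+1) $$

% Base case n = 4: the sum is empty, so f(4) = 2 + C(1) + C(3).
% Inductive step for even n >= 6:
$$
\begin{aligned}
f(n) &= C(n-1) + f(n-2) + C(n-3) \\
     &= C(n-1) + \Bigl[\, 2 + C(1) + C(n-3) + 2\sum_{j=1}^{n/2-3} C(2j+1) \Bigr] + C(n-3) \\
     &= 2 + C(1) + C(n-1) + 2\sum_{j=1}^{n/2-2} C(2j+1)
\end{aligned}
$$
```

The last step works because the extra 2C(n-3) is exactly the j = n/2 - 2 term of the sum. In the code below, `(3..x - 1).step_by(2)` enumerates exactly these doubled k = 2j + 1.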

After all of that, here is how I would write the solution. Keep in mind that async Rust is not my specialty, so there may be a shorter way to write this.

use futures::prelude::*;
use futures::stream::FuturesUnordered;

// The maximum number of API calls allowed in flight at once in the
// unordered futures group.
const API_CALL_LIMIT: usize = 64;

#[tokio::main]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let answer = f(60).await;
    println!("{}", answer);
    Ok(())
}

async fn f(n: usize) -> usize {
    match n {
        0 => 1,
        2 => 2,
        x if x % 2 != 0 => perform_api_call(x).await,
        x => {
            let mut group = FuturesUnordered::new();

            // Create iterator of values to put through the API
            let mut search_values = (3..x - 1).step_by(2);

            // Fill the unordered futures group up to the call limit
            for y in (&mut search_values).take(API_CALL_LIMIT) {
                group.push(perform_api_call(y));
            }

            let mut total = 0;

            // Once call limit is reached, the remaining values to get from the
            // API are added when a new position opens up.
            for y in search_values {
                // Wait for empty position
                if let Some(result) = group.next().await {
                    total += 2 * result;
                } else {
                    unreachable!("There must be remaining items if we are still pushing entries");
                }

                // Push new future to replace the one we just popped
                group.push(perform_api_call(y));
            }

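            // The first and last calls are special since they are not doubled.
            // Futures do nothing until polled, so awaiting them one after the
            // other would run them sequentially; join them with the draining
            // of the group so all remaining requests run concurrently.
            let drain_group = async {
                let mut rest = 0;
                while let Some(result) = group.next().await {
                    rest += 2 * result;
                }
                rest
            };
            let (rest, first, last) =
                tokio::join!(drain_group, perform_api_call(1), perform_api_call(x - 1));

            // Add in the first and last entries along with f(2).
            2 + total + rest + first + last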
        }
    }
}

async fn perform_api_call(_: usize) -> usize {
    100
}
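The index bookkeeping in `f` (the `(3..x - 1).step_by(2)` range plus the two undoubled calls) is easy to get off by one. A quick synchronous way to gain confidence, sketched under the assumption that CallAPI can be replaced by a deterministic mock, is to compare the same ranges against the plain recursive definition for many even n. `f_closed_form`, `f_reference`, and the mock `|k| k * k + 7` are hypothetical; the mock gives each k a distinct value, so a misplaced index would usually show up as a mismatch.

```rust
/// Closed form for even n >= 4, using the same ranges as the async `f`:
/// odd k in 3..n-1 are doubled, CallAPI(1) and CallAPI(n - 1) are not.
fn f_closed_form(n: usize, call_api: impl Fn(usize) -> usize) -> usize {
    assert!(n >= 4 && n % 2 == 0, "closed form only covers even n >= 4");
    let doubled: usize = (3..n - 1).step_by(2).map(|k| 2 * call_api(k)).sum();
    2 + call_api(1) + call_api(n - 1) + doubled
}

/// The recurrence exactly as stated in the question, used as the reference.
fn f_reference(n: usize, call_api: &impl Fn(usize) -> usize) -> usize {
    match n {
        0 => 1,
        2 => 2,
        _ if n % 2 != 0 => call_api(n),
        _ => f_reference(n - 1, call_api) + f_reference(n - 2, call_api) + f_reference(n - 3, call_api),
    }
}
```

With the constant mock `|_| 100`, `f_closed_form(60, ...)` gives 5802, the same value the async program prints.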


The difference is much more pronounced if we simulate the time a real HTTP request would take:

use std::time::Duration;

// Simulate a slow HTTP request (tokio::time::sleep needs tokio's "time" feature).
async fn perform_api_call(_: usize) -> usize {
    tokio::time::sleep(Duration::from_millis(500)).await;
    100
}

Upvotes: 1
