I recently took a coding exam and faced a problem like the one below. I solved it in a different language, but now I'm trying to solve it in Rust to improve my Rust skills.
You are given a number n and an API endpoint. Implement the following function f(n) and output the value of f(n):

f(0) = 1
f(2) = 2
f(n) = f(n-1) + f(n-2) + f(n-3) (when n % 2 == 0)
f(n) = CallAPI(n) (when n % 2 != 0)

The CallAPI function returns a number between 0 and 100, and it always returns the same value for the same input n. CallAPI should be invoked as few times as possible.
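To make the recurrence concrete, here is a minimal synchronous sketch of f(n) with memoization. It assumes a hypothetical stand-in `call_api` (a deterministic formula whose result stays between 0 and 100) in place of the real endpoint, and it counts invocations so you can see that each odd k is requested only once.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for CallAPI: deterministic, returns a value in 0..=100,
/// and increments `calls` every time it is invoked.
fn call_api(n: usize, calls: &mut usize) -> usize {
    *calls += 1;
    (n * 37) % 101
}

/// f(n) from the problem statement. Results are memoized, so each f(k),
/// and therefore each CallAPI(k), is computed at most once.
fn f(n: usize, memo: &mut HashMap<usize, usize>, calls: &mut usize) -> usize {
    if let Some(&v) = memo.get(&n) {
        return v;
    }
    let v = match n {
        0 => 1,
        2 => 2,
        _ if n % 2 != 0 => call_api(n, calls),
        _ => f(n - 1, memo, calls) + f(n - 2, memo, calls) + f(n - 3, memo, calls),
    };
    memo.insert(n, v);
    v
}

fn main() {
    let mut memo = HashMap::new();
    let mut calls = 0;
    let answer = f(60, &mut memo, &mut calls);
    println!("f(60) = {answer}, CallAPI invoked {calls} times");
}
```

For n = 60 this needs exactly one call per odd value from 1 to 59, i.e. 30 calls.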
I thought it could be solved with memoization, so I wrote the following code:
use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};

use async_recursion::async_recursion;

type Memo = Arc<Mutex<HashMap<usize, usize>>>;

#[tokio::main]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // In a real situation, n would be passed as an input argument.
    let n = "60";
    let mut temp = HashMap::new();
    temp.insert(0, 1);
    temp.insert(2, 2);
    let memo: Memo = Arc::new(Mutex::new(temp));
    let ans = f(n.parse::<usize>().expect("error"), memo).await;
    println!("{}", ans);
    Ok(())
}

#[async_recursion]
async fn f(n: usize, memo: Memo) -> usize {
    if let Some(&ans) = memo.lock().unwrap().get(&n) {
        return ans;
    }
    if n % 2 == 0 {
        let memo1 = memo.clone();
        let memo2 = memo.clone();
        let memo3 = memo.clone();
        let (one, two, three) = tokio::join!(f(n - 1, memo1), f(n - 2, memo2), f(n - 3, memo3));
        let ans = one + two + three;
        let mut mut_memo = memo.lock().unwrap();
        mut_memo.insert(n, ans);
        ans
    } else {
        let n_str = n.to_string();
        let ans = ask_server(&n_str).await;
        let mut mut_memo = memo.lock().unwrap();
        mut_memo.insert(n, ans);
        ans
    }
}

async fn ask_server(_n: &str) -> usize {
    // In the real situation this sends an HTTP GET request (so it needs to be
    // async) and returns the number contained in the response, but here it
    // just returns 100 for simplicity.
    100
}
It seems to work, but its performance is definitely not good. (Sorry for my terrible code...) I assume it is because it locks the "memo" variable every time, but despite all my efforts I can't figure out how to improve it.
I would like to ask those who know Rust well how this could be implemented efficiently.
My first step was to look at the problem and figure out whether we could rewrite the recurrence in a more convenient form.
f(0) = 1
f(1) = CallAPI(1)
f(2) = 2
f(3) = CallAPI(3)
f(4) = CallAPI(3) + 2 + CallAPI(1)
f(5) = CallAPI(5)
f(6) = CallAPI(5) + 2 * CallAPI(3) + 2 + CallAPI(1)
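The pattern in these expansions can be checked mechanically. The sketch below (an illustration, not code from the question or the answer) expands f(n) into a constant plus a count of how many times each CallAPI(k) appears, without calling any API.

```rust
use std::collections::BTreeMap;

/// f(n) written as `constant + sum over k of coeff[k] * CallAPI(k)`.
#[derive(Debug, Clone, PartialEq)]
struct Expansion {
    constant: usize,
    coeff: BTreeMap<usize, usize>, // odd k -> number of times CallAPI(k) appears
}

fn expand(n: usize) -> Expansion {
    match n {
        0 => Expansion { constant: 1, coeff: BTreeMap::new() },
        2 => Expansion { constant: 2, coeff: BTreeMap::new() },
        _ if n % 2 != 0 => Expansion { constant: 0, coeff: BTreeMap::from([(n, 1)]) },
        _ => {
            // f(n) = f(n-1) + f(n-2) + f(n-3): add the three expansions term by term.
            let mut sum = Expansion { constant: 0, coeff: BTreeMap::new() };
            for part in [expand(n - 1), expand(n - 2), expand(n - 3)] {
                sum.constant += part.constant;
                for (k, c) in part.coeff {
                    *sum.coeff.entry(k).or_insert(0) += c;
                }
            }
            sum
        }
    }
}

fn main() {
    // Prints:
    // f(4) = 1*CallAPI(3) + 1*CallAPI(1) + 2
    // f(6) = 1*CallAPI(5) + 2*CallAPI(3) + 1*CallAPI(1) + 2
    // f(8) = 1*CallAPI(7) + 2*CallAPI(5) + 2*CallAPI(3) + 1*CallAPI(1) + 2
    for n in [4, 6, 8] {
        let e = expand(n);
        let terms: Vec<String> = e
            .coeff
            .iter()
            .rev()
            .map(|(k, c)| format!("{c}*CallAPI({k})"))
            .collect();
        println!("f({n}) = {} + {}", terms.join(" + "), e.constant);
    }
}
```

Every odd k between 3 and n - 3 ends up with coefficient 2, while CallAPI(1) and CallAPI(n - 1) appear once and the constant is always 2.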
From here, the first pattern I notice is that once n is over 3 there are generally two cases. If n is odd, we can just use CallAPI directly. If n is even, however, our total becomes (CallAPI(n - 1) + CallAPI(n - 3)) + (CallAPI(n - 3) + CallAPI(n - 5)) + ... + (CallAPI(3) + CallAPI(1)) + 2, because for even n we have f(n) = CallAPI(n - 1) + f(n - 2) + CallAPI(n - 3), all the way down to f(2) = 2. As you can see, the only time a CallAPI value is repeated is in the pair that directly follows it, so CallAPI(1) and CallAPI(n - 1) appear once and every other odd value appears exactly twice. This is great because we can convert this into an iterative approach in which the API calls do not depend on each other, so many of them can be grouped together and awaited concurrently. While we could attempt to perform every CallAPI concurrently, we don't want to use up all of the memory on the system or overload the API server, so we can impose a limit on the number of requests in flight at any given time.
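A minimal sketch of the closed form just described, assuming a hypothetical synchronous `call_api` closure in place of the real endpoint: the even case becomes one flat loop over the odd values, each requested once, and a direct transcription of the recursion is kept only as a cross-check.

```rust
/// Closed form for even n >= 4:
/// f(n) = 2 + CallAPI(1) + CallAPI(n - 1) + 2 * (CallAPI(3) + CallAPI(5) + ... + CallAPI(n - 3)).
/// `call_api` is a caller-supplied stand-in for the real API (hypothetical).
fn f_closed(n: usize, call_api: impl Fn(usize) -> usize) -> usize {
    match n {
        0 => 1,
        2 => 2,
        _ if n % 2 != 0 => call_api(n),
        _ => {
            // Each odd value in 3..=n-3 is requested exactly once and doubled;
            // no request depends on the result of another one.
            let middle: usize = (3..n - 1).step_by(2).map(|k| 2 * call_api(k)).sum();
            2 + call_api(1) + call_api(n - 1) + middle
        }
    }
}

/// Direct transcription of the recursive definition, used only as a cross-check.
fn f_recursive(n: usize, call_api: &impl Fn(usize) -> usize) -> usize {
    match n {
        0 => 1,
        2 => 2,
        _ if n % 2 != 0 => call_api(n),
        _ => f_recursive(n - 1, call_api) + f_recursive(n - 2, call_api) + f_recursive(n - 3, call_api),
    }
}

fn main() {
    let api = |k: usize| (k * 37) % 101; // hypothetical deterministic stand-in
    for n in (0..=30).step_by(2) {
        assert_eq!(f_closed(n, api), f_recursive(n, &api));
    }
    println!("closed form matches the recursion for even n up to 30");
}
```

Because the loop body for one k never depends on another k, the same requests can be issued concurrently instead of in sequence.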
Putting all of that together, here is how I would write the solution. Keep in mind that async Rust is not my specialty, so there may be a shorter way to write this.
use futures::prelude::*;
use futures::stream::FuturesUnordered;

// The maximum number of API calls that may be in flight at once in the
// unordered futures group.
const API_CALL_LIMIT: usize = 64;

#[tokio::main]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let answer = f(60).await;
    println!("{}", answer);
    Ok(())
}

async fn f(n: usize) -> usize {
    match n {
        0 => 1,
        2 => 2,
        x if x % 2 != 0 => perform_api_call(x).await,
        x => {
            let mut group = FuturesUnordered::new();

            // Create iterator of values to put through the API
            // (3, 5, ..., x - 3; each of these is counted twice)
            let mut search_values = (3..x - 1).step_by(2);

            // Fill the unordered futures group up to the call limit
            for y in (&mut search_values).take(API_CALL_LIMIT) {
                group.push(perform_api_call(y));
            }

            let mut total = 0;

            // Once the call limit is reached, the remaining values to get from
            // the API are added whenever a position opens up.
            for y in search_values {
                // Wait for an empty position
                if let Some(result) = group.next().await {
                    total += 2 * result;
                } else {
                    unreachable!("There must be remaining items if we are still pushing entries");
                }

                // Push a new future to replace the one we just popped
                group.push(perform_api_call(y));
            }

            // Wait for the remaining entries in the group
            let drain = async {
                let mut rest = 0;
                while let Some(result) = group.next().await {
                    rest += 2 * result;
                }
                rest
            };

            // The first and last calls are special since they are not doubled.
            // Futures are lazy, so they are driven together with the drain via
            // `join!` instead of being awaited one after another at the end.
            let (rest, first, last) =
                futures::join!(drain, perform_api_call(1), perform_api_call(x - 1));

            // Add in the first and last entries along with f(2).
            2 + total + rest + first + last
        }
    }
}

async fn perform_api_call(_: usize) -> usize {
    100
}
The difference is much more pronounced if we simulate the time it would take for the HTTP request to be performed.
use std::time::Duration;

// Simulate about 500 ms of request latency (needs tokio's `time` feature, which `full` includes).
async fn perform_api_call(_: usize) -> usize {
    tokio::time::sleep(Duration::from_millis(500)).await;
    100
}