Reputation: 11
I'm currently having an issue with Rust's borrowing rules. I have a HashMap of `Value` structs, each of which contains a list of keys to other Values in the same HashMap. I am attempting to recursively call a function on these Values, and that function requires a reference to the HashMap.
use std::collections::HashMap;
/// A node stored in the pool `HashMap`.
struct Value {
/// This node's own value; `evaluate` uses it as the starting point of `sum`.
val: f64,
/// Keys (into the pool `HashMap`) of the nodes whose evaluated sums are added to this node's.
prevs: Vec<usize>,
/// Result of the last `evaluate` call; `0.0` for a freshly built node.
sum: f64,
}
impl Value {
/// Builds node `i` with `val = 0.1`, `sum = 0.0`, and `prevs` set to every key in `0..i`.
pub fn new(i: usize) -> Value {
let mut res = Value {
val: 0.1,
prevs: Vec::new(),
sum: 0.0,
};
// Link this node to every lower-numbered node.
for j in 0..i {
res.prevs.push(j);
}
res
}
/// Intended to compute `val` plus the recursively evaluated sum of every node in `prevs`,
/// storing the result in `sum` and returning it.
///
/// NOTE: this is the function the reported E0499 error points at; it does not compile as written.
pub fn evaluate(&mut self, pool: &mut HashMap<usize, Value>) -> f64 {
self.sum = self.val;
for i in &self.prevs {
// `prev` is a mutable reference into `pool`, so `pool` stays mutably borrowed
// for as long as `prev` is in use ...
let prev = pool.get_mut(i).unwrap();
// ... and passing `pool` again here is the second mutable borrow that the
// compiler rejects (E0499).
self.sum += prev.evaluate(pool);
}
self.sum
}
}
/// Builds a pool of ten nodes keyed `0..10` and tries to print the evaluated sum of node `9`.
fn main() {
let mut hm: HashMap<usize, Value> = HashMap::new();
// Node `i` is created with links to all nodes `0..i` (see `Value::new`).
for i in 0..10 {
hm.insert(i, Value::new(i));
}
// NOTE(review): `get` hands back a shared reference into `hm`, while `evaluate` takes
// `&mut self` plus a `&mut` borrow of the same map — expect the borrow checker to
// reject this line as well.
println!("{}", hm.get(&9).unwrap().evaluate(&mut hm));
}
Error:
error[E0499]: cannot borrow `*pool` as mutable more than once at a time
--> src/lib.rs:25:39
|
24 | let prev = pool.get_mut(i).unwrap();
| --------------- first mutable borrow occurs here
25 | self.sum += prev.evaluate(pool);
| -------- ^^^^ second mutable borrow occurs here
| |
| first borrow later used by call
I'm attempting to calculate the output of a neural network (usually done via feedforward) by starting from the output and recursively evaluating each node as a weighted sum of the nodes connected to it, with an unpredictable topology. This requires each node to have a list of input_nodes, which are keys into a node pool HashMap.
Upvotes: 1
Views: 389
Reputation: 4007
Below is a sample with a few variants:
- `Arc<Mutex>`
- `Vec` and `split_at_mut`
- `Vec` and raw pointers. At least it evaluates to the same number; I wanted to add it for a performance comparison.
#![feature(test)]
extern crate test;
use std::{collections::HashMap, sync::{Arc, Mutex}};
/// A node in the evaluation graph.
///
/// Each node carries its own value, the indices of the nodes that feed into it,
/// and the result of its most recent evaluation.
#[derive(Debug)]
struct Value {
    /// This node's own value; every `evaluate*` method starts `sum` from it.
    val: f64,
    /// Indices (pool keys) of the nodes whose evaluated sums are added to this node's.
    prevs: Vec<usize>,
    /// Result of the most recent `evaluate*` call; `0.0` for a freshly built node.
    sum: f64,
}

impl Value {
    /// Creates node `i` with `val = 0.1`, `sum = 0.0`, and links to every node in `0..i`.
    pub fn new(i: usize) -> Value {
        Value {
            val: 0.1,
            prevs: (0..i).collect(),
            sum: 0.0,
        }
    }

    /// Evaluates this node against a pool of `Arc<Mutex<Value>>` nodes.
    ///
    /// Sets `sum` to `val` plus the recursively evaluated sum of every node listed in
    /// `prevs`, and returns it.
    ///
    /// # Panics
    ///
    /// Panics if an index in `prevs` is missing from `pool`, or if a node's mutex is
    /// poisoned. A node reachable from itself through `prevs` would try to lock a mutex
    /// that is already held by this thread, which `std::sync::Mutex` answers with a
    /// deadlock or a panic.
    pub fn evaluate(&mut self, pool: &mut HashMap<usize, Arc<Mutex<Value>>>) -> f64 {
        self.sum = self.val;
        for i in &self.prevs {
            // Cloning the `Arc` ends the borrow of `pool` before it is handed to the
            // recursive call; a shared lookup is all that is needed for that.
            let prev = Arc::clone(
                pool.get(i)
                    .expect("`prevs` refers to a node that is missing from the pool"),
            );
            self.sum += prev
                .lock()
                .expect("node mutex was poisoned by an earlier panic")
                .evaluate(pool);
        }
        self.sum
    }

    /// Evaluates this node against a slice of nodes, using `split_at_mut` to obtain
    /// disjoint mutable borrows.
    ///
    /// For each index `i` in `prevs`, node `pool[i]` is evaluated with `pool[..i]` as its
    /// own pool, so every node may only depend on nodes that come before it in the slice,
    /// and `self` must not be an element of `pool`.
    ///
    /// # Panics
    ///
    /// Panics if an index in `prevs` is not a valid index into `pool`.
    pub fn evaluate_split(&mut self, pool: &mut [Value]) -> f64 {
        self.sum = self.val;
        for &i in &self.prevs {
            // `earlier` is `pool[..i]`; the first element of `rest` is `pool[i]`. The two
            // halves never overlap, so both can be borrowed mutably at once.
            let (earlier, rest) = pool.split_at_mut(i);
            let prev = rest
                .first_mut()
                .expect("`prevs` index is out of range for the pool");
            self.sum += prev.evaluate_split(earlier);
        }
        self.sum
    }

    // OBS! Don't do this, horribly unsafe and wrong
    /// Evaluates this node against a pool passed as a raw pointer plus length.
    ///
    /// Kept only as a performance comparison for the safe variants.
    ///
    /// # Safety
    ///
    /// `pool` must point to `pool_len` initialized, contiguous `Value`s that stay valid
    /// for the whole call. Even then this function is not sound in general: it casts the
    /// `*const` pointer to `*mut` and builds a `&mut [Value]` over the entire pool, so if
    /// `self` is itself an element of that pool two mutable references alias the same
    /// node, which Rust's reference rules forbid.
    pub unsafe fn evaluate_unsafe(&mut self, pool: *const Value, pool_len: usize) -> f64 {
        let pool = std::slice::from_raw_parts_mut(pool as *mut Value, pool_len);
        self.sum = self.val;
        for i in &self.prevs {
            // Raw parts are re-derived from the slice so the recursive call can rebuild
            // its own view of the whole pool.
            let (pool_ptr, pool_len) = (pool.as_ptr(), pool.len());
            self.sum += pool[*i].evaluate_unsafe(pool_ptr, pool_len);
        }
        self.sum
    }
}
fn main() {
// arcmutex
let mut hm: HashMap<usize, Arc<Mutex<Value>>> = HashMap::new();
for i in 0..10 {
hm.insert(i, Arc::new(Mutex::new(Value::new(i))));
}
let val = hm.get(&9).unwrap().clone();
assert_eq!(val.lock().unwrap().evaluate(&mut hm), 51.2);
// split vec
let mut hm = (0..10).map(|v| {
Value::new(v)
}).collect::<Vec<_>>();
let (hm, val) = hm.split_at_mut(9);
assert_eq!((hm.len(), val.len()), (9, 1));
assert_eq!(val[0].evaluate_split(hm), 51.2);
}
#[cfg(test)]
mod tests {
    use super::*;
    use test::bench;

    /// Benchmarks the `Arc<Mutex<Value>>` + `HashMap` variant on a ten-node pool.
    #[bench]
    fn bench_arc_mutex(b: &mut bench::Bencher) {
        // The pool is built once, outside the timed closure.
        let mut pool: HashMap<usize, Arc<Mutex<Value>>> = (0..10)
            .map(|idx| (idx, Arc::new(Mutex::new(Value::new(idx)))))
            .collect();
        b.iter(|| {
            let output = Arc::clone(pool.get(&9).unwrap());
            let total = output.lock().unwrap().evaluate(&mut pool);
            assert_eq!(total, 51.2);
        });
    }

    /// Benchmarks the `Vec` + `split_at_mut` variant on a ten-node pool.
    #[bench]
    fn bench_split(b: &mut bench::Bencher) {
        let mut nodes: Vec<Value> = (0..10).map(Value::new).collect();
        b.iter(|| {
            let (inputs, output) = nodes.split_at_mut(9);
            assert_eq!(output[0].evaluate_split(inputs), 51.2);
        });
    }

    /// Benchmarks the raw-pointer variant on a ten-node pool.
    #[bench]
    fn bench_unsafe(b: &mut bench::Bencher) {
        let mut nodes: Vec<Value> = (0..10).map(Value::new).collect();
        b.iter(|| {
            // OBS! Don't do this, horribly unsafe and wrong
            let base = nodes.as_ptr();
            let len = nodes.len();
            let output = &mut nodes[9];
            assert_eq!(unsafe { output.evaluate_unsafe(base, len) }, 51.2);
        });
    }
}
Running `cargo bench`
results in:
running 3 tests
test tests::bench_arc_mutex ... bench: 13,249 ns/iter (+/- 367)
test tests::bench_split ... bench: 1,974 ns/iter (+/- 70)
test tests::bench_unsafe ... bench: 1,989 ns/iter (+/- 62)
Also, have a look at https://rust-unofficial.github.io/too-many-lists/index.html
Upvotes: 1