Dimitri_3gg
Dimitri_3gg

Reputation: 11

Borrow error when attempting recursion on HashMap, where each value needs a reference to the map

I'm currently having an issue with Rust's borrowing rules. I have a HashMap of `Value` structs, each of which contains a list of keys to other Values in the HashMap. I am attempting to recursively call a function on these Values, and that function requires a reference to the HashMap.

use std::collections::HashMap;

struct Value {
    val: f64, // this node's own contribution; `new` initialises it to 0.1
    prevs: Vec<usize>, // keys of the other `Value`s in the pool that this node sums over
    sum: f64, // written by `evaluate`: `val` plus the evaluated result of every entry in `prevs`
}
impl Value {
    pub fn new(i: usize) -> Value { // builds a node whose `prevs` are the keys 0..i
        let mut res = Value {
            val: 0.1,
            prevs: Vec::new(),
            sum: 0.0,
        };
        for j in 0..i {
            res.prevs.push(j);
        }
        res
    }

    pub fn evaluate(&mut self, pool: &mut HashMap<usize, Value>) -> f64 { // recursively sums `val` and every predecessor's result
        self.sum = self.val;
        for i in &self.prevs {
            let prev = pool.get_mut(i).unwrap(); // `prev` holds a mutable borrow of `*pool` ...
            self.sum += prev.evaluate(pool); // ... which is still live while `pool` is mutably borrowed again (E0499)
        }
        self.sum
    }
}

fn main() {
    let mut hm: HashMap<usize, Value> = HashMap::new();
    for i in 0..10 {
        hm.insert(i, Value::new(i)); // node i lists every earlier key 0..i in its `prevs`
    }
    println!("{}", hm.get(&9).unwrap().evaluate(&mut hm)); // evaluate the last node against the whole map
}

Error:

error[E0499]: cannot borrow `*pool` as mutable more than once at a time
  --> src/lib.rs:25:39
   |
24 |             let prev = pool.get_mut(i).unwrap();
   |                        --------------- first mutable borrow occurs here
25 |             self.sum += prev.evaluate(pool);
   |                              -------- ^^^^ second mutable borrow occurs here
   |                              |
   |                              first borrow later used by call

Playground

Context

I'm attempting to calculate the output of a neural network (usually done via feedforward) by starting from the output and recursively evaluating each node as a weighted sum of the nodes connected to it, with an unpredictable topology. This requires each node to have a list of input_nodes, which are keys into a node pool HashMap.

Upvotes: 1

Views: 389

Answers (1)

Mika Vatanen
Mika Vatanen

Reputation: 4007

Below is a sample with a few variants:

  1. A slow and probably deadlock-prone, but compiling, version using Arc<Mutex<Value>>
  2. A high-performance version using Vec and split_at_mut
  3. A highly unsafe, UB and "against-all-good-practices" version using Vec and raw pointers. It at least evaluates to the same number; I included it only for performance comparison.
#![feature(test)]
extern crate test;

use std::{collections::HashMap, sync::{Arc, Mutex}};

/// A node in a dependency graph: its result is its own `val` plus the
/// evaluated results of every node listed in `prevs`.
#[derive(Debug)]
struct Value {
    /// This node's own contribution to its sum.
    val: f64,
    /// Keys (for the map variant) or indices (for the slice variants) of the
    /// nodes this node depends on.
    prevs: Vec<usize>,
    /// Result stored by the most recent `evaluate*` call.
    sum: f64,
}
impl Value {
    /// Creates node `i` with `val = 0.1` that depends on every node `0..i`.
    pub fn new(i: usize) -> Value {
        Value {
            val: 0.1,
            prevs: (0..i).collect(),
            sum: 0.0,
        }
    }

    /// Variant 1: evaluates this node against a pool of `Arc<Mutex<Value>>`.
    ///
    /// Every predecessor is locked and evaluated recursively; nothing is
    /// memoised, so shared predecessors are evaluated once per path.
    ///
    /// The caller normally holds this node's own lock while calling, so a
    /// dependency cycle will try to lock an already-held mutex.
    ///
    /// # Panics
    ///
    /// Panics if a key in `prevs` is missing from `pool` or if a predecessor's
    /// mutex is poisoned.
    pub fn evaluate(&mut self, pool: &mut HashMap<usize, Arc<Mutex<Value>>>) -> f64 {
        self.sum = self.val;
        for i in &self.prevs {
            // A shared lookup is enough: cloning the handle ends the borrow of
            // `pool` before it is handed to the recursive call.
            let node = Arc::clone(
                pool.get(i)
                    .expect("`prevs` refers to a key that is missing from the pool"),
            );
            self.sum += node
                .lock()
                .expect("predecessor mutex is poisoned")
                .evaluate(pool);
        }
        self.sum
    }

    /// Variant 2: evaluates this node against the slice of nodes preceding it.
    ///
    /// Predecessor `i` is split off from `pool` and evaluated against
    /// `pool[..i]`, so every node may only depend on nodes with a smaller
    /// index than its own. The borrows are disjoint, so no locking or
    /// `unsafe` is needed.
    ///
    /// # Panics
    ///
    /// Panics if an index in `prevs` is not smaller than `pool.len()`.
    pub fn evaluate_split(&mut self, pool: &mut [Value]) -> f64 {
        self.sum = self.val;
        for &i in &self.prevs {
            let (earlier, rest) = pool.split_at_mut(i);
            let node = rest
                .first_mut()
                .expect("`prevs` refers to an index past the end of the pool");
            self.sum += node.evaluate_split(earlier);
        }
        self.sum
    }

    // OBS! Don't do this, horribly unsafe and wrong
    /// Variant 3: evaluates this node through a raw pointer to the whole pool.
    ///
    /// Included only for performance comparison; prefer [`Value::evaluate_split`].
    ///
    /// # Safety
    ///
    /// * `pool` must point to `pool_len` initialised, contiguous `Value`s and
    ///   must be valid for reads and writes for the whole call, even though it
    ///   is typed `*const`.
    /// * No node reachable through `prevs` may be `self` or any node that is
    ///   already being evaluated further up the call stack (no cycles).
    /// * No other reference to a reachable node may be used during the call.
    ///
    /// # Panics
    ///
    /// Panics if an index in `prevs` is not smaller than `pool_len`.
    pub unsafe fn evaluate_unsafe(&mut self, pool: *const Value, pool_len: usize) -> f64 {
        // Work on single elements instead of materialising a `&mut [Value]`
        // over the whole pool, which would always alias `self`.
        let base = pool as *mut Value;

        self.sum = self.val;
        for &i in &self.prevs {
            assert!(
                i < pool_len,
                "predecessor index {} is out of bounds for a pool of {} nodes",
                i,
                pool_len
            );
            // SAFETY: `i < pool_len` was just checked, and the caller
            // guarantees the pool is valid for writes and that node `i` is not
            // borrowed anywhere else while it is being evaluated.
            let node = unsafe { &mut *base.add(i) };
            // SAFETY: the caller's contract is forwarded unchanged.
            self.sum += unsafe { node.evaluate_unsafe(pool, pool_len) };
        }
        self.sum
    }
}

fn main() {
    // arcmutex
    let mut hm: HashMap<usize, Arc<Mutex<Value>>> = HashMap::new();
    for i in 0..10 {
        hm.insert(i, Arc::new(Mutex::new(Value::new(i))));
    }
    
    let val = hm.get(&9).unwrap().clone();
    assert_eq!(val.lock().unwrap().evaluate(&mut hm), 51.2);

    // split vec
    let mut hm = (0..10).map(|v| {
        Value::new(v)
    }).collect::<Vec<_>>();

    let (hm, val) = hm.split_at_mut(9);
    assert_eq!((hm.len(), val.len()), (9, 1));
    assert_eq!(val[0].evaluate_split(hm), 51.2);

}


/// Benchmarks comparing the three evaluation variants on a ten-node pool.
/// Requires a nightly toolchain for the unstable `test` crate.
#[cfg(test)]
mod tests {
    use test::bench;
    use super::*;

    /// Variant 1: nodes behind `Arc<Mutex<_>>` in a `HashMap`.
    #[bench]
    fn bench_arc_mutex(b: &mut bench::Bencher) {
        let mut hm: HashMap<usize, Arc<Mutex<Value>>> = HashMap::new();
        for i in 0..10 {
            hm.insert(i, Arc::new(Mutex::new(Value::new(i))));
        }

        b.iter(|| {
            let val = hm.get(&9).unwrap().clone();
            assert_eq!(val.lock().unwrap().evaluate(&mut hm), 51.2);
        });
    }

    /// Variant 2: nodes in a `Vec`, borrowed disjointly via `split_at_mut`.
    #[bench]
    fn bench_split(b: &mut bench::Bencher) {
        let mut hm = (0..10).map(|v| {
            Value::new(v)
        }).collect::<Vec<_>>();

        b.iter(|| {
            let (hm, val) = hm.split_at_mut(9);
            assert_eq!(val[0].evaluate_split(hm), 51.2);
        });
    }

    /// Variant 3: nodes in a `Vec`, reached through a raw pointer.
    #[bench]
    fn bench_unsafe(b: &mut bench::Bencher) {
        let mut hm = (0..10).map(|v| {
            Value::new(v)
        }).collect::<Vec<_>>();

        b.iter(|| {
            // OBS! Don't do this, horribly unsafe and wrong
            // Take one *mutable* raw pointer and derive the root node from it,
            // so the pool pointer that gets written through is neither taken
            // from a shared borrow nor followed by a fresh `&mut hm[9]`.
            let (hm_ptr, hm_len) = (hm.as_mut_ptr(), hm.len());
            // SAFETY: the vector holds ten nodes, so offset 9 is in bounds and
            // points to an initialised `Value`.
            let val = unsafe { &mut *hm_ptr.add(9) };
            assert_eq!(unsafe { val.evaluate_unsafe(hm_ptr, hm_len) }, 51.2);
        });
    }

}

Running cargo bench gives the following results:

running 3 tests
test tests::bench_arc_mutex ... bench:      13,249 ns/iter (+/- 367)
test tests::bench_split     ... bench:       1,974 ns/iter (+/- 70)
test tests::bench_unsafe    ... bench:       1,989 ns/iter (+/- 62)

Also, have a look at https://rust-unofficial.github.io/too-many-lists/index.html

Upvotes: 1

Related Questions