Ach113

Reputation: 1825

Rust - how to find n-th most frequent element in a collection

I can't imagine this hasn't been asked before, but I have searched everywhere and could not find the answer.

I have an iterable that contains duplicate elements. I want to count the number of times each element occurs in this iterable and return the n-th most frequent one.

I have working code that does exactly that, but I really doubt it's the optimal way to achieve this.

use std::collections::{BinaryHeap, HashMap};

// returns n-th most frequent element in collection
pub fn most_frequent<T: std::hash::Hash + std::cmp::Eq + std::cmp::Ord>(array: &[T], n: u32) -> &T {
    // initialize empty hashmap
    let mut map = HashMap::new();

    // count occurrences of each element in iterable and save as (value, count) in hashmap
    for value in array {
        // taken from https://doc.rust-lang.org/std/collections/struct.HashMap.html#method.entry
        // not exactly sure how this works
        let counter = map.entry(value).or_insert(0);
        *counter += 1;
    }

    // determine highest frequency of some element in the collection
    let mut heap: BinaryHeap<_> = map.values().collect();
    let mut max = heap.pop().unwrap();
    // get n-th largest value
    for _i in 1..n {
        max = heap.pop().unwrap();
    }

    // find that element (get key from value in hashmap)
    // taken from https://stackoverflow.com/questions/59401720/how-do-i-find-the-key-for-a-value-in-a-hashmap
    map.iter()
        .find_map(|(key, &val)| if val == *max { Some(key) } else { None })
        .unwrap()
}

Are there any better ways or more efficient std methods to achieve what I want? Or maybe there are some community-made crates that I could use.

Upvotes: 5

Views: 3866

Answers (1)

Sven Marnach

Reputation: 602305

Your implementation has a worst-case time complexity of O(n log n), where n is the length of the array. The standard heap-based solution to this problem has a complexity of O(n log k) for retrieving the k-th most frequent element. That solution also uses a binary heap, but not in the way you used it.

Here's a suggested implementation of the common algorithm:

use std::cmp::{Eq, Ord, Reverse};
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;

pub fn most_frequent<T>(array: &[T], k: usize) -> Vec<(usize, &T)>
where
    T: Hash + Eq + Ord,
{
    let mut map = HashMap::new();
    for x in array {
        *map.entry(x).or_default() += 1;
    }

    let mut heap = BinaryHeap::with_capacity(k + 1);
    for (x, count) in map.into_iter() {
        heap.push(Reverse((count, x)));
        if heap.len() > k {
            heap.pop();
        }
    }
    heap.into_sorted_vec().into_iter().map(|r| r.0).collect()
}

(Playground)

I changed the prototype of the function to return a vector of the k most frequent elements together with their counts, since this is what you need to keep track of anyway. If you only want the k-th most frequent element, you can take it from the result with [k - 1].1. Note that this indexing panics if the array has fewer than k distinct elements.
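If you prefer a function that returns only the k-th most frequent element, a thin wrapper is enough. The following is a minimal sketch, not part of the original code: `kth_most_frequent` is a hypothetical helper name, and returning `Option` for an out-of-range k (instead of panicking) is my own assumption about what a caller would want. `most_frequent` is repeated so the block compiles on its own.

```rust
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;

/// Same top-k routine as above, repeated so this sketch is self-contained.
/// Returns up to `k` (count, element) pairs, most frequent first.
pub fn most_frequent<T>(array: &[T], k: usize) -> Vec<(usize, &T)>
where
    T: Hash + Eq + Ord,
{
    let mut map: HashMap<&T, usize> = HashMap::new();
    for x in array {
        *map.entry(x).or_default() += 1;
    }

    let mut heap = BinaryHeap::with_capacity(k + 1);
    for (x, count) in map {
        heap.push(Reverse((count, x)));
        if heap.len() > k {
            heap.pop();
        }
    }
    heap.into_sorted_vec().into_iter().map(|r| r.0).collect()
}

/// Hypothetical helper: the k-th most frequent element (1-based),
/// or `None` if k is 0 or exceeds the number of distinct elements.
pub fn kth_most_frequent<T>(array: &[T], k: usize) -> Option<&T>
where
    T: Hash + Eq + Ord,
{
    if k == 0 {
        return None;
    }
    let top = most_frequent(array, k);
    // The vector is ordered from most to least frequent, so the
    // k-th most frequent element sits at index k - 1.
    top.get(k - 1).map(|&(_, x)| x)
}
```

Elements with equal counts are ordered by the element value itself, because the heap compares (count, element) tuples. If ties matter for your use case, decide explicitly how they should be broken.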

The algorithm itself first builds a map of element counts the same way your code does – I just wrote it in a more concise form.

Next, we build a BinaryHeap for the most frequent elements. After each iteration, this heap contains at most k elements, which are the most frequent ones seen so far. If there are more than k elements in the heap, we drop the least frequent element. Since we always drop the least frequent element seen so far, the heap always retains the k most frequent elements seen so far. We need to use the Reverse wrapper to get a min heap, as described in the documentation of BinaryHeap.
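To see the bounded min-heap idea in isolation, here is a small sketch on plain integers. `k_largest` is a hypothetical name used only for this illustration; it keeps the k largest values of a slice the same way the function above keeps the k most frequent elements.

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Keeps the `k` largest values of `items` using a size-bounded min-heap.
/// (Hypothetical helper, for illustration only.)
pub fn k_largest(items: &[u32], k: usize) -> Vec<u32> {
    // BinaryHeap is a max-heap; wrapping values in `Reverse` flips the
    // ordering, so `pop()` removes the smallest value instead.
    let mut heap = BinaryHeap::with_capacity(k + 1);
    for &item in items {
        heap.push(Reverse(item));
        if heap.len() > k {
            // Drop the smallest value seen so far; the heap never
            // holds more than k values after this step.
            heap.pop();
        }
    }
    // Ascending order of `Reverse<u32>` is descending order of the
    // wrapped values, so the result is largest first.
    heap.into_sorted_vec()
        .into_iter()
        .map(|Reverse(v)| v)
        .collect()
}
```

For example, `k_largest(&[5, 1, 9, 3, 7], 3)` pushes each value, evicts 1 and then 3 when the heap grows past three entries, and ends up with 9, 7 and 5.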

Finally, we collect the results into a vector. The into_sorted_vec() function basically does this job for us, but we still want to unwrap the items from their Reverse wrapper, since that wrapper is an implementation detail of our function and should not be returned to the caller.

(In Rust Nightly, we could also use the into_iter_sorted() method, saving one vector allocation.)

The code in this answer makes sure the heap is essentially limited to k elements, so an insertion to the heap has a complexity of O(log k). In your code, you put all the counts into the heap at once, without limiting the size of the heap, so heap operations cost O(log n) in the worst case. You essentially use the binary heap to sort a list of counts. That works, but it's certainly neither the easiest nor the fastest way to achieve that, so there is little justification for going that route.
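For comparison, if sorting all the counts is acceptable for your input sizes, the straightforward way is to sort a vector directly rather than going through a heap. This is a sketch of that simpler alternative, not the bounded-heap algorithm above; `kth_most_frequent_by_sorting` is a hypothetical name, and returning `Option` for an out-of-range k is my own assumption.

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Hypothetical helper: counts the elements, sorts all (count, element)
/// pairs in descending order, and returns the k-th entry (1-based).
/// Sorting costs O(m log m) for m distinct elements.
pub fn kth_most_frequent_by_sorting<T>(array: &[T], k: usize) -> Option<&T>
where
    T: Hash + Eq + Ord,
{
    let mut counts: HashMap<&T, usize> = HashMap::new();
    for x in array {
        *counts.entry(x).or_default() += 1;
    }

    let mut pairs: Vec<(usize, &T)> =
        counts.into_iter().map(|(x, count)| (count, x)).collect();
    // Descending by count; equal counts are ordered by element value.
    pairs.sort_unstable_by(|a, b| b.cmp(a));

    // `checked_sub` turns k == 0 into `None` instead of underflowing.
    pairs.get(k.checked_sub(1)?).map(|&(_, x)| x)
}
```

This does the same job as the heap-of-all-counts approach in the question with less code, while the bounded heap remains the better choice when k is much smaller than the number of distinct elements.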

Upvotes: 2
