benroth

Reputation: 2618

Java: top-n elements from stream source

Assume you read data items and associated scores from a "stream" source (i.e. no random access or multiple passes possible).

What is the best way to keep in memory, at any time, only the n elements with the lowest scores encountered so far? I am interested in the "Java" way of doing it (the shorter the idiom, the better) rather than the algorithm ("use a search tree, insert the new element, delete the biggest if the size is exceeded").

Below is the solution I came up with. However, I find it a bit lengthy, and some of its behaviour might be unexpected: the same item with different scores may be kept multiple times, while the same item added twice with the same score is kept only once. I also feel there should be something existing for this.

import java.util.AbstractMap.SimpleEntry;
import java.util.Map.Entry;
import java.util.Comparator;
import java.util.TreeSet;

/**
 * Stores the n smallest (by score) elements only.
 */
public class TopN<T extends Comparable<T>> {
  private TreeSet<Entry<T, Double>> elements;
  private int n;

  public TopN(int n) {
    this.n = n;
    this.elements = new TreeSet<Entry<T, Double>>(
        new Comparator<Entry<T, Double>>() {
          @Override
          public int compare(Entry<T, Double> o1, Entry<T, Double> o2) {
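            // Order by score first; Double.compare also handles NaN consistently.
            int byScore = Double.compare(o1.getValue(), o2.getValue());
            if (byScore != 0) return byScore;
            // Same score: break ties by key, sorting null keys last
            // (the old version returned 1 for two null keys and could throw on a null o2 key).
            T k1 = o1.getKey(), k2 = o2.getKey();
            if (k1 == null) return k2 == null ? 0 : 1;
            if (k2 == null) return -1;
            return k1.compareTo(k2);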
          }
    });
  }

  /**
   * Adds the element if the score is lower than the n-th smallest score.
   */
  public void add(T element, double score) {
    Entry<T, Double> keyVal = new SimpleEntry<T, Double>(element,score);
    elements.add(keyVal);
    if (elements.size() > n) {
      elements.pollLast();
    }
  }

  /**
   * Returns the elements with n smallest scores.
   */
  public TreeSet<Entry<T, Double>> get() {
    return elements;
  }
}
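The deduplication quirk described above comes from TreeSet itself: it decides whether an element is already present by using the comparator, not equals(), so two entries that compare as 0 count as the same element. The following is a minimal sketch using the same score-then-key ordering as TopN's comparator, without its null handling. The class name TreeSetDedupDemo is only for illustration.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.Comparator;
import java.util.Map.Entry;
import java.util.TreeSet;

public class TreeSetDedupDemo {

    // Builds a small set ordered by score, then by key (like TopN, minus null handling).
    static TreeSet<Entry<String, Double>> demo() {
        Comparator<Entry<String, Double>> byScore = Entry.comparingByValue();
        Comparator<Entry<String, Double>> byScoreThenKey =
                byScore.thenComparing(Entry.comparingByKey());
        TreeSet<Entry<String, Double>> set = new TreeSet<>(byScoreThenKey);

        set.add(new SimpleEntry<>("a", 1.0));
        set.add(new SimpleEntry<>("a", 1.0)); // compare() returns 0: rejected as a duplicate
        set.add(new SimpleEntry<>("a", 2.0)); // different score: stored as a separate entry
        return set;
    }

    public static void main(String[] args) {
        System.out.println(demo()); // [a=1.0, a=2.0]
    }
}
```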

There is a similar question, but it doesn't include the stream source / memory requirement: Find top N elements in an Array

Upvotes: 2

Views: 4971

Answers (2)

Pritesh Mhatre

Reputation: 4065

You can use Guava's Comparators class to get the desired result. The sample below gets the top 5 numbers. See the Javadoc of Guava's Comparators class for the API.

import java.util.Comparator;
import java.util.List;
import java.util.stream.Collector;

import org.junit.Test;

import com.google.common.collect.Comparators;
import com.google.common.collect.Lists;

public class TestComparator {

    @Test
    public void testTopN() {
        final List<Integer> numbers = Lists.newArrayList(1, 3, 8, 2, 6, 4, 7, 5, 9, 0);
        final Collector<Integer, ?, List<Integer>> collector = Comparators.greatest(5,
                Comparator.<Integer>naturalOrder());
        final List<Integer> top = numbers.stream().collect(collector);
        System.out.println(top);
    }

}

Output: [9, 8, 7, 6, 5]
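Since the question asks for the lowest scores, the matching Guava collector would be Comparators.least(k, comparator) rather than greatest. If adding Guava is not an option, the same idea of keeping at most k + 1 elements while the stream is consumed can be implemented with the JDK alone. The following is a minimal sketch assuming Java 8+. Note that leastK is a hypothetical helper, not a JDK or Guava API.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.stream.Collector;
import java.util.stream.Stream;

public class LeastK {

    // Hypothetical helper (not a JDK or Guava API): collects the k smallest
    // elements according to `comparator`, holding at most k + 1 in memory.
    static <T> Collector<T, ?, List<T>> leastK(int k, Comparator<T> comparator) {
        if (k < 0) throw new IllegalArgumentException("k must be non-negative");
        return Collector.of(
                // Reversed order makes the head the largest (worst) element kept so far.
                () -> new PriorityQueue<T>(comparator.reversed()),
                (heap, item) -> offerBounded(heap, item, k),
                // Combiner, only used by parallel streams: merge one heap into the other.
                (left, right) -> {
                    for (T item : right) offerBounded(left, item, k);
                    return left;
                },
                // Finisher: copy out and sort ascending (best first).
                heap -> {
                    List<T> result = new ArrayList<>(heap);
                    result.sort(comparator);
                    return result;
                });
    }

    private static <T> void offerBounded(PriorityQueue<T> heap, T item, int k) {
        heap.offer(item);
        if (heap.size() > k) {
            heap.poll(); // evict the current worst element
        }
    }

    public static void main(String[] args) {
        List<Integer> smallest = Stream.of(1, 3, 8, 2, 6, 4, 7, 5, 9, 0)
                .collect(leastK(5, Comparator.<Integer>naturalOrder()));
        System.out.println(smallest); // [0, 1, 2, 3, 4]
    }
}
```

Because the stream is consumed one element at a time, this also works for sources that can only be read once, such as the lazily populated stream from BufferedReader.lines().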

Upvotes: 1

dty

Reputation: 18998

Use a "heap" data structure. Java has a built-in one: PriorityQueue. Simply define your comparator for "best", and feed all your data from the stream into the priority queue.
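PriorityQueue's head is the smallest element under the comparator it is given: peek() and poll() return that element. Passing Collections.reverseOrder() makes the head the largest element instead, which is what you want when the head should be the element to throw away. The following is a minimal sketch. HeadDemo and headAfterAdding are illustrative names only.

```java
import java.util.Collections;
import java.util.PriorityQueue;

public class HeadDemo {

    // Adds the values and returns the head, i.e. the element poll() would remove next.
    static int headAfterAdding(PriorityQueue<Integer> queue, int... values) {
        for (int v : values) {
            queue.add(v);
        }
        return queue.peek();
    }

    public static void main(String[] args) {
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();                           // natural order
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Collections.reverseOrder()); // reversed order
        System.out.println(headAfterAdding(minHeap, 5, 1, 4)); // 1
        System.out.println(headAfterAdding(maxHeap, 5, 1, 4)); // 5
    }
}
```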

EDIT:

To add a bit more colour to this answer, you probably need to do something like this:

  • Define a comparator that works the opposite way to what you want (i.e. favours the items you want to throw away) - or define one that works the right way, and wrap it with Collections.reverseOrder(...)
  • Iterate over your data and put each element into the pqueue.
  • With each insert, if the size of the pqueue is >n, use poll() to remove the "top" element from the heap - which, because of your comparator, will actually be the "worst" one.

What you're left with is a pqueue containing the n "least bad" elements.
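Putting these steps together for the question's case (keep the n lowest scores) gives the minimal sketch below, assuming Java 8+. LowestN, Scored and sorted() are hypothetical names. Unlike the question's TreeSet, a PriorityQueue keeps entries that compare as equal, so adding the same item twice with the same score keeps both copies. A PriorityQueue's iterator also does not visit elements in priority order, so the result is copied and sorted before it is returned.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

/** Sketch: keeps the n lowest-scored items seen so far, using a bounded max-heap. */
public class LowestN<T> {

    /** Hypothetical holder pairing an item with its score. */
    public static final class Scored<E> {
        final E item;
        final double score;

        Scored(E item, double score) {
            this.item = item;
            this.score = score;
        }

        @Override
        public String toString() {
            return item + "=" + score;
        }
    }

    private final int n;
    // Reverse score order: the head is the highest (worst) score kept so far.
    private final PriorityQueue<Scored<T>> heap;

    public LowestN(int n) {
        if (n < 1) throw new IllegalArgumentException("n must be positive");
        this.n = n;
        this.heap = new PriorityQueue<>(n + 1,
                Comparator.comparingDouble((Scored<T> s) -> s.score).reversed());
    }

    public void add(T item, double score) {
        heap.offer(new Scored<>(item, score));
        if (heap.size() > n) {
            heap.poll(); // drop the worst; the heap never holds more than n + 1 entries
        }
    }

    /** Returns the kept entries sorted by ascending score. */
    public List<Scored<T>> sorted() {
        List<Scored<T>> result = new ArrayList<>(heap);
        result.sort(Comparator.comparingDouble(s -> s.score));
        return result;
    }

    public static void main(String[] args) {
        LowestN<String> top = new LowestN<>(2);
        top.add("a", 3.0);
        top.add("b", 1.0);
        top.add("c", 2.0);
        top.add("d", 5.0);
        System.out.println(top.sorted()); // [b=1.0, c=2.0]
    }
}
```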

Upvotes: 6
