Manuel Weiß

Reputation: 53

Grouping an Integer List into Partitions

Is there an easy way to do the following with a stream?

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public static void main(String[] args) {
    List<Integer> integerList = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
    System.out.println(partitioningValues(integerList, 3));
}

private static Map<Integer, List<Integer>> partitioningValues(List<Integer> integerList, int numberOfPartitions) {

    Map<Integer, List<Integer>> integerListMap = new HashMap<>();
    BigDecimal limit = BigDecimal.valueOf(integerList.size() / (double) numberOfPartitions);
    int limitRounded = limit.setScale(0, RoundingMode.UP).intValue(); // ceil(size / numberOfPartitions)

    for (int i = 0; i < numberOfPartitions; i++) {

        int toIndex = Math.min((i + 1) * limitRounded, integerList.size());
        integerListMap.put(i, integerList.subList(i * limitRounded, toIndex));
    }

    return integerListMap;
}

Result:

{0=[1, 2, 3, 4], 1=[5, 6, 7, 8], 2=[9, 10]}

Upvotes: 1

Views: 640

Answers (3)

Saravana

Reputation: 12817

You can use Collectors.groupingBy to split the list.

If the stream needs to be split by the elements' values (this works here because the values run from 1 to 10):

int split = 4;
Map<Integer, List<Integer>> map2 = integerList.stream().collect(Collectors.groupingBy(i -> (i-1) / split));
System.out.println(map2);
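A self-contained sketch of the value-based grouping, wrapped in a hypothetical ValueGrouping class. It assumes, as in the question, that the values are consecutive integers starting at 1; with other values the key reflects each value, not its position.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class ValueGrouping {

    // Buckets values 1..split -> key 0, split+1..2*split -> key 1, and so on.
    // Assumes values start at 1: the key is computed from the value, not its index.
    static Map<Integer, List<Integer>> groupByValue(List<Integer> values, int split) {
        return values.stream()
                .collect(Collectors.groupingBy(i -> (i - 1) / split));
    }

    public static void main(String[] args) {
        List<Integer> integerList = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        System.out.println(groupByValue(integerList, 4));
    }
}
```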

If the stream needs to be split by position (this relies on the stream being sequential, because pos is shared mutable state):

int[] pos = { -1 };
Map<Integer, List<Integer>> map = integerList.stream().peek(e -> pos[0]++).collect(Collectors.groupingBy(e -> pos[0] / split));
System.out.println(map);
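The peek counter only works because the stream is sequential: peek exists mainly for debugging, and on a parallel stream the shared pos array would be updated concurrently. A sketch of a stateless alternative (a different technique from the answer's, with hypothetical class and method names) streams the indexes instead of the elements and looks each element up with List.get.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class IndexGrouping {

    // Indexes 0..split-1 -> key 0, split..2*split-1 -> key 1, and so on.
    // No shared mutable state, so the grouping does not depend on evaluation order.
    static <T> Map<Integer, List<T>> groupByPosition(List<T> list, int split) {
        return IntStream.range(0, list.size())
                .boxed()
                .collect(Collectors.groupingBy(
                        i -> i / split,                                             // key from the index
                        Collectors.mapping(i -> list.get(i), Collectors.toList()))); // value: element at that index
    }

    public static void main(String[] args) {
        List<Integer> integerList = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        System.out.println(groupByPosition(integerList, 4));
    }
}
```

Because the key comes from the index rather than the value, duplicate values are handled correctly as well.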

Output:

{0=[1, 2, 3, 4], 1=[5, 6, 7, 8], 2=[9, 10]}

Upvotes: 1

Nick Vanderhoven

Reputation: 3093

If you don't want to mutate a shared variable to keep track of the indexes, and want to keep the stream parallelizable, you can use an alternative strategy.

The partition size is the maximum number of integers in a single partition. In all code snippets, let us define partitionSize as follows:

int partitionSize = (list.size() - 1) / partitions + 1;

where the concise -1/+1 trick computes the integer ceiling of list.size() / partitions without Math.ceil (assuming a non-empty list).
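A brute-force check of the -1/+1 trick against Math.ceil, assuming a non-empty list and at least one partition (the class and method names are hypothetical).

```java
public class CeilingCheck {

    // Integer ceiling of size / partitions; valid for size >= 1 and partitions >= 1.
    static int partitionSize(int size, int partitions) {
        return (size - 1) / partitions + 1;
    }

    public static void main(String[] args) {
        // Compare the trick with Math.ceil over a small range of inputs.
        for (int size = 1; size <= 100; size++) {
            for (int partitions = 1; partitions <= 20; partitions++) {
                int expected = (int) Math.ceil(size / (double) partitions);
                if (partitionSize(size, partitions) != expected) {
                    throw new AssertionError(size + "/" + partitions);
                }
            }
        }
        System.out.println(partitionSize(10, 3)); // ceil(10 / 3) = 4
    }
}
```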

A simple, naive approach would be to group by each element's index:

list.stream().collect(groupingBy(i -> list.indexOf(i) / partitionSize));

But if you care about performance, you want a better way to handle the indexes, since indexOf scans the list from the start for every element, which makes the whole grouping quadratic.
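The naive version also breaks on duplicate values, because indexOf always returns the first occurrence. A small demonstration with a made-up list (the class name is hypothetical):

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class IndexOfPitfall {

    // Naive grouping: the key comes from the FIRST index at which each value occurs.
    static Map<Integer, List<Integer>> naive(List<Integer> list, int partitionSize) {
        return list.stream()
                .collect(Collectors.groupingBy(i -> list.indexOf(i) / partitionSize));
    }

    public static void main(String[] args) {
        List<Integer> withDuplicates = Arrays.asList(5, 5, 5, 5);
        // indexOf(5) is 0 for every element, so all four land in partition 0
        // instead of being split into two partitions of two.
        System.out.println(naive(withDuplicates, 2)); // {0=[5, 5, 5, 5]}
    }
}
```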

An intuitive approach is to first generate all the start indexes, then iterate over them and collect the sublists. This gives something like the following, collecting all the partitions into a List<List<Integer>>:

int[] indexes = IntStream.iterate(0, i -> i + partitionSize).limit(partitions+1).toArray();

IntStream.range(0, indexes.length - 1)
         .mapToObj(i -> list.subList(indexes[i], Math.min(indexes[i + 1], list.size())))
         .collect(toList());

Here Math.min keeps the end of the last interval from running past the end of the list.
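A runnable version of the index-array approach, using the list and partition count from the question. As an assumption of this sketch, the start of each interval is clamped with Math.min too: if partitions is large relative to the list (for example 7 partitions of 10 elements), the last start index can exceed list.size() and subList would throw IndexOutOfBoundsException. With the clamp, trailing partitions are simply empty. The class name is hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class IndexArrayPartition {

    static <T> List<List<T>> partition(List<T> list, int partitions) {
        int partitionSize = (list.size() - 1) / partitions + 1; // ceil(size / partitions)
        // Start index of every partition, plus one extra entry marking the end of the last one.
        int[] indexes = IntStream.iterate(0, i -> i + partitionSize)
                .limit(partitions + 1)
                .toArray();
        return IntStream.range(0, indexes.length - 1)
                .mapToObj(i -> list.subList(
                        Math.min(indexes[i], list.size()),      // clamp the start as well as the end
                        Math.min(indexes[i + 1], list.size())))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> integerList = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        System.out.println(partition(integerList, 3)); // [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]
    }
}
```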

You can, however, combine the index calculation and the loop:

IntStream.range(0, (list.size() + partitionSize - 1) / partitionSize) // ceil(size / partitionSize) chunks
         .mapToObj(i -> list.subList(i * partitionSize, Math.min((i + 1) * partitionSize, list.size())))
         .collect(toList());

Note that the result is a List<List<Integer>> in which each index holds the sublist of one partition.
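A runnable sketch of the combined loop, assuming a non-empty list and at least one partition (class and method names are hypothetical). It counts chunks as ceil(size / partitionSize), so a list whose size is an exact multiple of partitionSize (for example 9 elements in 3 partitions) does not end with an empty sublist.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ChunkPartition {

    // Splits `list` into at most `partitions` consecutive sublists of (almost) equal size.
    // Assumes list is non-empty and partitions >= 1.
    static <T> List<List<T>> partition(List<T> list, int partitions) {
        int partitionSize = (list.size() - 1) / partitions + 1;         // ceil(size / partitions)
        int chunks = (list.size() + partitionSize - 1) / partitionSize; // ceil(size / partitionSize)
        return IntStream.range(0, chunks)
                .mapToObj(i -> list.subList(i * partitionSize,
                        Math.min((i + 1) * partitionSize, list.size())))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(partition(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9), 3));     // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        System.out.println(partition(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 4)); // [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]
    }
}
```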

If you really want a map with keys 0, 1, 2, ..., you can collect to a Map instead, using the partition number as the key:

Map<Integer, List<Integer>> result =
      IntStream.range(0, (list.size() + partitionSize - 1) / partitionSize)
               .boxed()
               .collect(Collectors.toMap(i -> i, i -> list.subList(i * partitionSize, Math.min((i + 1) * partitionSize, list.size()))));
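A self-contained sketch of the map variant, assuming the keys should be partition numbers starting at 0. It uses the four-argument toMap with a TreeMap supplier so that the keys iterate in ascending order; TreeMap, the merge function and the class name are choices of this sketch, not part of the answer.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ChunkMap {

    // Maps partition number -> sublist; assumes a non-empty list and partitions >= 1.
    static <T> Map<Integer, List<T>> partition(List<T> list, int partitions) {
        int partitionSize = (list.size() - 1) / partitions + 1;         // ceil(size / partitions)
        int chunks = (list.size() + partitionSize - 1) / partitionSize; // ceil(size / partitionSize)
        return IntStream.range(0, chunks)
                .boxed()
                .collect(Collectors.toMap(
                        i -> i,                                         // key: partition number
                        i -> list.subList(i * partitionSize, Math.min((i + 1) * partitionSize, list.size())),
                        (a, b) -> a,                                    // keys are unique, so never invoked
                        TreeMap::new));                                 // ascending key order
    }

    public static void main(String[] args) {
        List<Integer> integerList = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        System.out.println(partition(integerList, 3)); // {0=[1, 2, 3, 4], 1=[5, 6, 7, 8], 2=[9, 10]}
    }
}
```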

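Or, if you don't mind using an external library, Guava has Lists.partition. Note that its second argument is the size of each sublist, not the number of partitions, so pass partitionSize rather than the partition count:

Lists.partition(integerList, partitionSize);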
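For projects without Guava, a JDK-only sketch with the same size-based semantics as Lists.partition (consecutive sublists of the given size, the last one possibly shorter). chunkBySize is a hypothetical name; unlike Guava, which returns a view, this collects subList views into a new list.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ChunkBySize {

    // Consecutive sublists of `size` elements each; the last one may be shorter.
    static <T> List<List<T>> chunkBySize(List<T> list, int size) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive");
        }
        int chunks = (list.size() + size - 1) / size; // ceil(list.size() / size)
        return IntStream.range(0, chunks)
                .mapToObj(i -> list.subList(i * size, Math.min((i + 1) * size, list.size())))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> integerList = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        System.out.println(chunkBySize(integerList, 3)); // [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]
    }
}
```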

Example:

List<Integer> list = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

int partitions = 4;
int partitionSize = (list.size() - 1) / partitions + 1; //ceil

List<List<Integer>> result = IntStream.range(0, (list.size() + partitionSize - 1) / partitionSize)
                                      .mapToObj(i -> list.subList(i * partitionSize, Math.min((i + 1) * partitionSize, list.size())))
                                      .collect(toList());

System.out.println(result);

Result: [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]

Upvotes: 0

user6904265

Reputation: 1938

I propose this approach: it iterates from 0 to numberOfPartitions and at each step creates a sublist of batchLength elements (only the last step can have fewer than batchLength elements), collecting the sublists into a Map whose key is the current step and whose value is the sublist for that step.

public static Map<Integer, List<Integer>> partitioningValues(List<Integer> integerList, int numberOfPartitions) {
    int size = integerList.size();
    BigDecimal limit = BigDecimal.valueOf(size / (double) numberOfPartitions);
    int batchLength = limit.setScale(0, RoundingMode.UP).intValue(); // ceil(size / numberOfPartitions)
    return IntStream.range(0, numberOfPartitions)
            .boxed()
            .collect(
                Collectors.toMap(
                    s -> s, // the stream element already is the step number
                    s -> integerList.subList(s * batchLength, Math.min((s + 1) * batchLength, size)))
            );
}
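A runnable sketch of this approach, with the key taken directly from the stream element s (the step number itself). Two changes are assumptions of this sketch rather than the answerer's code: Math.ceil replaces the BigDecimal rounding, and start indexes are clamped to the list size, so asking for more partitions than the list can fill (for example 7 partitions of 10 elements) yields empty trailing sublists instead of an IndexOutOfBoundsException. The class name is hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class StepPartition {

    // Key: step number s; value: the s-th batch of at most batchLength elements.
    static Map<Integer, List<Integer>> partitioningValues(List<Integer> integerList, int numberOfPartitions) {
        int size = integerList.size();
        int batchLength = (int) Math.ceil(size / (double) numberOfPartitions);
        return IntStream.range(0, numberOfPartitions)
                .boxed()
                .collect(Collectors.toMap(
                        s -> s,
                        s -> integerList.subList(
                                Math.min(s * batchLength, size),          // clamp the start index
                                Math.min((s + 1) * batchLength, size)))); // clamp the end index
    }

    public static void main(String[] args) {
        List<Integer> integerList = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        System.out.println(partitioningValues(integerList, 3));
    }
}
```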

A groupingBy version (very similar to @Saravana's second solution):

...
AtomicInteger pos = new AtomicInteger(0);
// relies on a sequential stream: elements must be counted in encounter order
Map<Integer, List<Integer>> map = integerList.stream()
        .collect(Collectors.groupingBy(e -> pos.getAndIncrement() / batchLength));
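groupingBy without a map factory makes no guarantee about the type or iteration order of the returned map. If the partitions should come out in key order, the three-argument overload accepts a map supplier. A sketch with TreeMap::new that keeps the sequential position counter; the class name and the batchLength of 4 in main are assumptions for illustration.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

public class SortedGrouping {

    static Map<Integer, List<Integer>> group(List<Integer> integerList, int batchLength) {
        AtomicInteger pos = new AtomicInteger(0);
        // Sequential stream only: the counter assigns positions in encounter order.
        return integerList.stream()
                .collect(Collectors.groupingBy(
                        e -> pos.getAndIncrement() / batchLength, // key: position / batchLength
                        TreeMap::new,                              // keys iterate in ascending order
                        Collectors.toList()));
    }

    public static void main(String[] args) {
        List<Integer> integerList = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        System.out.println(group(integerList, 4)); // {0=[1, 2, 3, 4], 1=[5, 6, 7, 8], 2=[9, 10]}
    }
}
```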

Upvotes: 0
