Reputation: 11308
How can we make a reduce on a stream so that it short-circuits when it encounters an absorbing element for the reducing operation?
The typical mathematical case would be 0 for multiplication. This stream:
int product = IntStream.of(2, 3, 4, 5, 0, 7, 8)
    .reduce(1, (a, b) -> a * b);
will consume the last two elements (7 and 8) regardless of the fact that once 0 has been encountered the product is known.
Note: This is essentially the same question as How to short-circuit a reduction of boolean values combined using || on Stream?. However, since that question focuses on a stream of boolean values, and its answer cannot be generalized for other types and reduce operations, I'd like to ask the more general question.
Upvotes: 31
Views: 3889
Reputation: 21256
Here's a solution that uses a custom short-circuiting Gatherer in the upcoming Java 24 to filter the stream, keeping all elements until the absorbing element 0 or the end of the stream is reached. A multiplication reduction is then performed on the filtered stream produced by the gathering operation (including the absorbing element, if found).
IntStream stream = IntStream.of(2, 3, 4, 5, 0, 7, 8);
int product = stream
.boxed()
.gather(Gatherer.<Integer, Integer>of((_, element, downstream) ->
downstream.push(element) && element != 0))
.reduce(1, (a, b) -> a * b);
If the stream is sequential, the gatherer essentially performs a takeWhile operation, but it also keeps the first non-matching element.
This will only short-circuit if the absorbing element appears in the stream, and not if it is produced during the reduction. For instance, in Java 1874919424 * 1874919424 == 0 due to int overflow, but this implementation will not short-circuit when the product becomes 0 from these two values being in the stream.
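The overflow claim can be checked directly: Java int multiplication wraps modulo 2^32, and 1874919424 is 28609 * 2^16, so its square is a multiple of 2^32.

```java
// Sketch checking the overflow claim: 1874919424 == 28609 * 2^16, so its
// square is 28609^2 * 2^32, which wraps to 0 in 32-bit int arithmetic.
public class OverflowDemo {
    public static void main(String[] args) {
        int x = 1874919424;
        System.out.println(x * x); // prints 0
    }
}
```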
Upvotes: 0
Reputation: 21256
This can be accomplished using a custom short-circuiting Gatherer in the upcoming Java 24.
IntStream stream = IntStream.of(2, 3, 4, 5, 0, 7, 8);
class State { int total = 1; }
int product = stream
.boxed()
.gather(Gatherer.<Integer, State, Integer>ofSequential(
State::new,
(state, element, downstream) -> {
state.total *= element;
return state.total != 0;
},
(state, downstream) -> downstream.push(state.total)
))
.findFirst()
.orElseThrow();
This gatherer converts the stream to a single-element stream containing the product of all elements in the stream.
The gatherer consumes elements from the stream and multiplies them. It indicates that it needs no more elements — i.e. short-circuits — if the product is ever 0 in the integrator.
The gatherer can be made parallel, though the short-circuiting behavior would apply per parallel piece of work. The overall processing would only finish once all pieces of work complete, either through short-circuiting at 0 or processing all its elements.
int product = stream
.boxed()
.parallel()
.gather(Gatherer.<Integer, State, Integer>of(
State::new,
(state, element, downstream) -> {
state.total *= element;
return state.total != 0;
},
(s1, s2) -> {
s1.total *= s2.total;
return s1;
},
(state, downstream) -> downstream.push(state.total)
))
.findFirst()
.orElseThrow();
This can be made into a general-purpose gatherer that supports any reduction operation and absorbing element, i.e. a short-circuiting equivalent to Gatherers.fold:
public static <T, R> Gatherer<T, ?, R> fold(
R initial, R absorbingElement, BiFunction<? super R, ? super T, ? extends R> folder) {
class State { R folded = initial; }
return Gatherer.ofSequential(
State::new,
(state, element, downstream) -> {
state.folded = folder.apply(state.folded, element);
return !Objects.equals(state.folded, absorbingElement);
},
(state, downstream) -> downstream.push(state.folded)
);
}
int product = Stream.of(2, 3, 4, 5, 0, 7, 8)
.gather(fold(1, 0, (a, b) -> a * b))
.findFirst()
.orElseThrow();
int min = Stream.of(1, 6, Integer.MIN_VALUE, 5, 6)
.gather(fold(Integer.MAX_VALUE, Integer.MIN_VALUE, Integer::min))
.findFirst()
.orElseThrow();
Upvotes: 1
Reputation: 3758
Since the introduction of takeWhile in Java 9, it can be done like this (sequential streams only, since the predicate and the accumulator share mutable state):
int[] last = {1};
int product = IntStream.of(2, 3, 4, 5, 0, 7, 8)
    .takeWhile(i -> last[0] != 0)
    .reduce(1, (a, b) -> (last[0] = a) * b);
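Note that the predicate tests the accumulator stored on the previous step, so the stream stops shortly after the 0 rather than exactly at it. A quick sketch with a longer tail and a counter shows the later elements are still skipped:

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

public class TakeWhileDemo {
    public static void main(String[] args) {
        AtomicInteger seen = new AtomicInteger();
        int[] last = {1};
        int product = IntStream.of(2, 3, 4, 5, 0, 7, 8, 9, 10)
                .peek(i -> seen.incrementAndGet())   // count elements drawn from the source
                .takeWhile(i -> last[0] != 0)        // stop once the stored accumulator is 0
                .reduce(1, (a, b) -> (last[0] = a) * b);
        // The predicate lags one step behind the reduction, so a couple of
        // elements past the 0 are still drawn, but not the whole tail.
        System.out.println(product + " after drawing " + seen.get() + " of 9 elements");
    }
}
```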
Upvotes: 1
Reputation: 100279
Unfortunately the Stream API has limited capabilities for creating your own short-circuiting operations. A not-so-clean solution is to throw a RuntimeException and catch it. Here's an implementation for IntStream, but it can be generalized for other stream types as well:
public static int reduceWithCancelEx(IntStream stream, int identity,
IntBinaryOperator combiner, IntPredicate cancelCondition) {
class CancelException extends RuntimeException {
private final int val;
CancelException(int val) {
this.val = val;
}
}
try {
return stream.reduce(identity, (a, b) -> {
int res = combiner.applyAsInt(a, b);
if(cancelCondition.test(res))
throw new CancelException(res);
return res;
});
} catch (CancelException e) {
return e.val;
}
}
Usage example:
int product = reduceWithCancelEx(
IntStream.of(2, 3, 4, 5, 0, 7, 8).peek(System.out::println),
1, (a, b) -> a * b, val -> val == 0);
System.out.println("Result: "+product);
Output:
2
3
4
5
0
Result: 0
Note that even though it works with parallel streams, it's not guaranteed that the other parallel tasks will be finished as soon as one of them throws an exception. Sub-tasks that have already started will likely run to completion, so you may process more elements than expected.
Update: here's an alternative solution which is much longer, but more parallel-friendly. It's based on a custom spliterator which returns at most one element: the result of accumulating all the underlying elements. When you use it in sequential mode, it does all the work in a single tryAdvance call. When you split it, each part generates the corresponding single partial result, and these are reduced by the Stream engine using the combiner function. Here's the generic version, but a primitive specialization is possible as well.
final static class CancellableReduceSpliterator<T, A> implements Spliterator<A>,
Consumer<T>, Cloneable {
private Spliterator<T> source;
private final BiFunction<A, ? super T, A> accumulator;
private final Predicate<A> cancelPredicate;
private final AtomicBoolean cancelled = new AtomicBoolean();
private A acc;
CancellableReduceSpliterator(Spliterator<T> source, A identity,
BiFunction<A, ? super T, A> accumulator, Predicate<A> cancelPredicate) {
this.source = source;
this.acc = identity;
this.accumulator = accumulator;
this.cancelPredicate = cancelPredicate;
}
@Override
public boolean tryAdvance(Consumer<? super A> action) {
if (source == null || cancelled.get()) {
source = null;
return false;
}
while (!cancelled.get() && source.tryAdvance(this)) {
if (cancelPredicate.test(acc)) {
cancelled.set(true);
break;
}
}
source = null;
action.accept(acc);
return true;
}
@Override
public void forEachRemaining(Consumer<? super A> action) {
tryAdvance(action);
}
@Override
public Spliterator<A> trySplit() {
if(source == null || cancelled.get()) {
source = null;
return null;
}
Spliterator<T> prefix = source.trySplit();
if (prefix == null)
return null;
try {
@SuppressWarnings("unchecked")
CancellableReduceSpliterator<T, A> result =
(CancellableReduceSpliterator<T, A>) this.clone();
result.source = prefix;
return result;
} catch (CloneNotSupportedException e) {
throw new InternalError();
}
}
@Override
public long estimateSize() {
// let's pretend we have the same number of elements
// as the source, so the pipeline engine parallelize it in the same way
return source == null ? 0 : source.estimateSize();
}
@Override
public int characteristics() {
return source == null ? SIZED : source.characteristics() & ORDERED;
}
@Override
public void accept(T t) {
this.acc = accumulator.apply(this.acc, t);
}
}
Methods analogous to Stream.reduce(identity, accumulator, combiner) and Stream.reduce(identity, combiner), but with a cancelPredicate:
public static <T, U> U reduceWithCancel(Stream<T> stream, U identity,
BiFunction<U, ? super T, U> accumulator, BinaryOperator<U> combiner,
Predicate<U> cancelPredicate) {
return StreamSupport
.stream(new CancellableReduceSpliterator<>(stream.spliterator(), identity,
accumulator, cancelPredicate), stream.isParallel()).reduce(combiner)
.orElse(identity);
}
public static <T> T reduceWithCancel(Stream<T> stream, T identity,
BinaryOperator<T> combiner, Predicate<T> cancelPredicate) {
return reduceWithCancel(stream, identity, combiner, combiner, cancelPredicate);
}
Let's test both versions, counting how many elements are actually processed. We put the 0 close to the end. The exception version:
AtomicInteger count = new AtomicInteger();
int product = reduceWithCancelEx(
IntStream.range(-1000000, 100).filter(x -> x == 0 || x % 2 != 0)
.parallel().peek(i -> count.incrementAndGet()), 1,
(a, b) -> a * b, x -> x == 0);
System.out.println("product: " + product + "/count: " + count);
Thread.sleep(1000);
System.out.println("product: " + product + "/count: " + count);
Typical output:
product: 0/count: 281721
product: 0/count: 500001
So while the result is returned after only some elements have been processed, the tasks continue working in the background and the counter keeps increasing. Here's the spliterator version:
AtomicInteger count = new AtomicInteger();
int product = reduceWithCancel(
IntStream.range(-1000000, 100).filter(x -> x == 0 || x % 2 != 0)
.parallel().peek(i -> count.incrementAndGet()).boxed(),
1, (a, b) -> a * b, x -> x == 0);
System.out.println("product: " + product + "/count: " + count);
Thread.sleep(1000);
System.out.println("product: " + product + "/count: " + count);
Typical output:
product: 0/count: 281353
product: 0/count: 281353
All the tasks are actually finished when the result is returned.
Upvotes: 13
Reputation: 11308
My own take on this is not to use reduce() per se, but to use an existing short-circuiting terminal operation. noneMatch() or allMatch() can be used for this with a Predicate that has a side effect. Admittedly also not the cleanest solution, but it does achieve the goal:
AtomicInteger product = new AtomicInteger(1);
IntStream.of(2, 3, 4, 5, 0, 7, 8)
.peek(System.out::println)
.noneMatch(i -> {
if (i == 0) {
product.set(0);
return true;
}
int oldValue = product.get();
while (oldValue != 0 && !product.compareAndSet(oldValue, i * oldValue)) {
oldValue = product.get();
}
return oldValue == 0;
});
System.out.println("Result: " + product.get());
It short-circuits and can be made parallel.
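To demonstrate the parallel claim, the same predicate can run on a parallel stream; the compareAndSet loop keeps concurrent multiplications consistent, and once any thread sets the product to 0 the CAS fails for everyone else, so the result is still deterministic (a sketch):

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

public class ParallelNoneMatchDemo {
    public static void main(String[] args) {
        AtomicInteger product = new AtomicInteger(1);
        IntStream.of(2, 3, 4, 5, 0, 7, 8)
            .parallel()
            .noneMatch(i -> {
                if (i == 0) {
                    product.set(0);
                    return true; // short-circuits the match operation
                }
                int oldValue = product.get();
                // CAS loop: retry until the multiplication lands, or bail out
                // once another thread has already zeroed the product
                while (oldValue != 0 && !product.compareAndSet(oldValue, i * oldValue)) {
                    oldValue = product.get();
                }
                return oldValue == 0;
            });
        System.out.println("Result: " + product.get()); // prints Result: 0
    }
}
```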
Upvotes: 3
Reputation: 12122
A general short-circuiting static reduce method can be implemented using the spliterator of a stream. It even turned out to be not very complicated! Using spliterators seems to be the way to go much of the time when one wants to work with streams in a more flexible way.
public static <T> T reduceWithCancel(Stream<T> s, T acc, BinaryOperator<T> op, Predicate<? super T> cancelPred) {
BoxConsumer<T> box = new BoxConsumer<T>();
Spliterator<T> splitr = s.spliterator();
while (!cancelPred.test(acc) && splitr.tryAdvance(box)) {
acc = op.apply(acc, box.value);
}
return acc;
}
public static class BoxConsumer<T> implements Consumer<T> {
T value = null;
public void accept(T t) {
value = t;
}
}
Usage:
int product = reduceWithCancel(
Stream.of(1, 2, 0, 3, 4).peek(System.out::println),
1, (acc, i) -> acc * i, i -> i == 0);
System.out.println("Result: " + product);
Output:
1
2
0
Result: 0
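The same helper can drive any reduction that has an absorbing element. As a sketch (duplicating the answer's reduceWithCancel and BoxConsumer so the snippet is self-contained), here is a minimum that stops as soon as Integer.MIN_VALUE, the absorbing element for min, is seen:

```java
import java.util.Spliterator;
import java.util.function.BinaryOperator;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Stream;

public class MinDemo {
    static class BoxConsumer<T> implements Consumer<T> {
        T value = null;
        public void accept(T t) { value = t; }
    }

    static <T> T reduceWithCancel(Stream<T> s, T acc, BinaryOperator<T> op,
            Predicate<? super T> cancelPred) {
        BoxConsumer<T> box = new BoxConsumer<>();
        Spliterator<T> splitr = s.spliterator();
        // Pull one element at a time; stop as soon as the accumulator
        // satisfies the cancel condition.
        while (!cancelPred.test(acc) && splitr.tryAdvance(box)) {
            acc = op.apply(acc, box.value);
        }
        return acc;
    }

    public static void main(String[] args) {
        int min = reduceWithCancel(
                Stream.of(1, 6, Integer.MIN_VALUE, 5, 6).peek(System.out::println),
                Integer.MAX_VALUE, Integer::min, i -> i == Integer.MIN_VALUE);
        System.out.println("Result: " + min); // 5 and the trailing 6 are never consumed
    }
}
```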
The method could be generalised to perform other kinds of terminal operations.
This is based loosely on this answer about a take-while operation.
I don't know anything about the parallelisation potential of this.
Upvotes: 7