Reduction behaves strangely when using parallel stream but works fine for sequential stream in Java 8u5

import java.util.stream.Stream;

class Foo {
    int len;
}

public class Main {
    public static void main(String[] args) throws Exception {
        System.out.println(Stream.of("alpha", "beta", "gamma", "delta").parallel().reduce(
                new Foo(),
                (f, s) -> { f.len += s.length(); return f; },
                (f1, f2) -> {
                    Foo f = new Foo();
                    /* check self-reduction
                    if (f1 == f2) {
                        System.out.println("equal");
                        f.len = f1.len;
                        return f;
                    }
                    */
                    f.len = f1.len + f2.len;
                    return f;
                }
        ).len);
    }
}

The code tries to count the total length of several strings.

This piece of code prints 19 only if
1. I use a sequential stream (by removing the parallel() call), or
2. I use Integer instead of Foo, which is simply a wrapper around an int.

Otherwise the console prints 20 or 36 instead. To debug this, I enabled the commented-out "check self-reduction" block, which does change the output: "equal" always gets printed twice, and the console then prints sometimes 8, sometimes 10.

My understanding is that reduce() is Java's implementation of a parallel foldr/foldl. The third argument of reduce(), the combiner, is used to merge the partial results produced by the parallel execution of the reduction. Is that right? If so, why would a reduction result ever need to be combined with itself? Further, how do I fix this code so that it gives the correct output and still runs in parallel?

EDIT: Please ignore the fact that I did not use a method reference to simplify the code; my ultimate goal is to zip by adding more fields to Foo.

Upvotes: 2

Views: 178

Answers (2)

Brian Goetz

Reputation: 95376

Your code is horribly broken. You are using a reducer function which fails the requirement that the accumulator/combiner functions be associative, stateless, and non-interfering. And a mutable Foo is not an identity for the reduction. All of these can lead to incorrect results when executed in parallel.

You're also making it far harder than you need to! Try this:

int totalLen = 
    Stream.of(... stuff ...)
          .parallel()
          .mapToInt(String::length)
          .sum();

or

int totalLen = 
    Stream.of(... stuff ...)
          .parallel()
          .mapToInt(String::length)
          .reduce(0, Integer::sum);

Further, you're trying to use reduce which reduces over values (which is why it works with Integer), but you're trying to use mutable state containers for your reduction result. If you want to reduce into a mutable state container (like a List or StringBuilder), use collect() instead, which is designed for mutation.
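As a minimal sketch of the collect() approach applied to the Foo container from the question: the class name CollectDemo and the helper methods add, merge, and totalLength are illustrative assumptions, not part of the original code. Each thread gets its own fresh Foo from the supplier, so no container is shared during accumulation.

```java
import java.util.stream.Stream;

public class CollectDemo {
    // Mutable container, as in the question; add/merge are hypothetical helpers.
    static class Foo {
        int len;

        // Accumulator: folds one element into this container.
        void add(String s) { len += s.length(); }

        // Combiner: folds another partial result into this container.
        void merge(Foo other) { len += other.len; }
    }

    static int totalLength(String... words) {
        return Stream.of(words)
                .parallel()
                // Supplier creates a new Foo per partial result, so the
                // accumulator never mutates a shared instance.
                .collect(Foo::new, Foo::add, Foo::merge)
                .len;
    }

    public static void main(String[] args) {
        System.out.println(totalLength("alpha", "beta", "gamma", "delta")); // prints 19
    }
}
```

Extra fields can be added to Foo as long as add and merge update all of them consistently.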

Upvotes: 9

ajb

Reputation: 31699

I think the problem is that the "identity" Foo is being reused too much.

Here's a modification where each Foo is given its own ID number so that we can track it:

import java.util.stream.Stream;

class Foo {
    private static int currId = 0;
    private static final Object lock = new Object();
    int id;
    int len;

    public Foo() {
        // Hand out a unique ID per instance so each Foo can be tracked in the output.
        synchronized (lock) {
            id = currId++;
        }
    }
}

public class Main {
    public static void main(String[] args) throws Exception {
        System.out.println(Stream.of("alpha", "beta", "gamma", "delta").parallel().reduce(
                new Foo(),
                (f, s) -> {
                    System.out.println("Adding to #" + f.id + ": " +
                        f.len + " + " + s.length() + " => " + (f.len + s.length()));
                    f.len += s.length();
                    return f;
                },
                (f1, f2) -> {
                    Foo f = new Foo();
                    f.len = f1.len + f2.len;
                    System.out.println("Creating new #" + f.id + " from #" + f1.id + " and #" + f2.id + ": " +
                        f1.len + " + " + f2.len + " => " + (f1.len + f2.len));
                    return f;
                }
        ).len);
    }
}

The output I get is:

Adding to #0: 0 + 5 => 5
Adding to #0: 0 + 4 => 4
Adding to #0: 5 + 5 => 10
Adding to #0: 9 + 5 => 14
Creating new #2 from #0 and #0: 19 + 19 => 38
Creating new #1 from #0 and #0: 14 + 14 => 28
Creating new #3 from #2 and #1: 38 + 28 => 66
66

It's not consistent every time. The thing I notice is that each time you say f.len += s.length(), it adds to the same Foo, which means that the first new Foo() is being executed only once, and lengths keep getting added into it, so that the same input strings' lengths are counted multiple times. Since there are apparently multiple parallel threads accessing it at the same time, the results above are a little strange and change from run to run.
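One way to avoid sharing the identity object while still using reduce() is to never mutate it: the accumulator returns a new value each time. The following is a sketch under that assumption; the names ImmutableReduceDemo, Len, and totalLength are made up for illustration and do not come from the question.

```java
import java.util.stream.Stream;

public class ImmutableReduceDemo {
    // Immutable value type (hypothetical stand-in for Foo).
    static final class Len {
        final int len;

        Len(int len) { this.len = len; }
    }

    static int totalLength(String... words) {
        return Stream.of(words)
                .parallel()
                .reduce(
                    // Identity is never modified, so it is safe to reuse across threads.
                    new Len(0),
                    // Accumulator builds a new Len instead of mutating its argument.
                    (acc, s) -> new Len(acc.len + s.length()),
                    // Combiner merges two partial results into a new Len.
                    (a, b) -> new Len(a.len + b.len))
                .len;
    }

    public static void main(String[] args) {
        System.out.println(totalLength("alpha", "beta", "gamma", "delta")); // prints 19
    }
}
```

This allocates one object per element, so for a plain sum it is heavier than mapToInt(String::length).sum(); it is mainly useful when the result carries several fields.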

Upvotes: 0
