17hao

Reputation: 309

Why is the output of this Java program different from the expected one?

I'm writing some code to simulate CAS (compare-and-swap).
Here I have a method cas that simulates the CAS instruction and a method increase that adds 1 to the field count. I start 2 threads, and each thread calls increase 10000 times.
The problem is that the expected output is 20000, but the actual output is a little smaller than 20000, for example 19984, 19992, 19989... It is different every time.
I would really appreciate it if you could help me.

public class SimulateCAS {
    private volatile int count;

    private synchronized int cas(int expectation, int newValue) {
        int curValue = count;
        if (expectation == curValue) {
            count = newValue;
        }
        return curValue;
    }

    void increase() {
        int newValue;
        do {
            newValue = count + 1;                       // ①
        } while (count != cas(count, newValue));        // ②
    }

    public static void main(String[] args) throws InterruptedException {
        final SimulateCAS demo = new SimulateCAS();
        Thread t1 = new Thread(() -> {
            for (int i = 0; i < 10000; i++) {
                demo.increase();
            }
        });
        Thread t2 = new Thread(() -> {
            for (int i = 0; i < 10000; i++) {
                demo.increase();
            }
        });

        t1.start();
        t2.start();
        t1.join();
        t2.join();
        System.out.println(demo.count);
    }
}

Upvotes: 2

Views: 90

Answers (1)

Johannes Kuhn

Reputation: 15202

The problem is your increase method.

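The value of count can be updated by the other thread at any point between the lines marked ① and ②.
Your implementation of increase assumes that this cannot happen, and that the count read in line ① is the same count that line ② sees. Line ② even reads count twice: once for the comparison and once as the expectation passed to cas. If the other thread incremented count in between, cas compares count against itself, succeeds, and writes the stale newValue, so one increment is lost.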
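To make the race concrete, here is a minimal sketch that replays one bad interleaving deterministically in a single thread. The class name `LostUpdateDemo` and the `betweenReadAndCas` hook are hypothetical: the hook stands in for the other thread incrementing `count` between ① and ②. Since Java evaluates the left operand of `!=` before the method call on the right, line ② re-reads the already incremented `count` and hands that same value to `cas` as the expectation.

```java
// Hypothetical demo: replays one bad interleaving of the question's
// increase() in a single thread, so the lost update is reproducible.
class LostUpdateDemo {
    private volatile int count;

    // Same simulated CAS as in the question.
    private synchronized int cas(int expectation, int newValue) {
        int curValue = count;
        if (expectation == curValue) {
            count = newValue;
        }
        return curValue;
    }

    // The question's increase(), plus a hypothetical hook that runs between
    // line (1) and line (2) to stand in for the other thread.
    void increase(Runnable betweenReadAndCas) {
        int newValue;
        do {
            newValue = count + 1;                  // (1) first read of count
            betweenReadAndCas.run();               // "other thread" runs here
        } while (count != cas(count, newValue));   // (2) count is read again
    }

    static int replay() {
        LostUpdateDemo demo = new LostUpdateDemo();
        boolean[] interfered = {false};
        demo.increase(() -> {
            if (!interfered[0]) {      // interfere on the first pass only
                interfered[0] = true;
                demo.count++;          // the other thread's increment: 0 -> 1
            }
        });
        return demo.count;             // two increments ran, but count is 1
    }

    public static void main(String[] args) {
        System.out.println(replay());  // prints 1, not 2
    }
}
```

Two increments ran (the hook's and the one inside `increase`), yet the counter ends at 1. With real threads this window opens only occasionally, which fits totals slightly below 20000 that change from run to run.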

A better implementation of increase would be:

void increase() {
    int oldValue, newValue;
    do {
        oldValue = count;  // get the current value
        newValue = oldValue + 1; // calculate the new value based on the old
    } while (oldValue != cas(oldValue, newValue)); // Do a compare and swap - if the oldValue is still the current value, change it to the newValue, otherwise not.
}
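A runnable version of the question's program with this `increase` swapped in might look like the sketch below. `SimulateCASFixed` and `runTwoThreads` are hypothetical names; `cas` is unchanged from the question. Because `oldValue` is read once and used both to compute `newValue` and as the expectation, `cas` only writes when no other increment happened in between, so every call adds exactly 1.

```java
// Hypothetical full program: the question's simulated cas() with the
// corrected increase() from the answer.
class SimulateCASFixed {
    private volatile int count;

    // Simulated CAS: synchronized makes the read-compare-write atomic.
    private synchronized int cas(int expectation, int newValue) {
        int curValue = count;
        if (expectation == curValue) {
            count = newValue;
        }
        return curValue;
    }

    void increase() {
        int oldValue, newValue;
        do {
            oldValue = count;          // read count exactly once
            newValue = oldValue + 1;   // derive the new value from that read
        } while (oldValue != cas(oldValue, newValue)); // retry if count moved
    }

    // Starts two threads that each call increase() perThread times.
    static int runTwoThreads(int perThread) {
        SimulateCASFixed demo = new SimulateCASFixed();
        Runnable task = () -> {
            for (int i = 0; i < perThread; i++) {
                demo.increase();
            }
        };
        Thread t1 = new Thread(task);
        Thread t2 = new Thread(task);
        t1.start();
        t2.start();
        try {
            t1.join();
            t2.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return demo.count;
    }

    public static void main(String[] args) {
        System.out.println(runTwoThreads(10_000)); // 20000
    }
}
```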

Here is your full code with a real CAS, so no locks are needed.
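That full code is not included in this copy. As a minimal sketch, assuming `java.util.concurrent.atomic.AtomicInteger` as the "real CAS" (the original may have used something else), the same retry loop can call `AtomicInteger.compareAndSet`, which atomically sets the value only if it still equals the expected one and returns whether it succeeded. `RealCASCounter` and `runTwoThreads` are hypothetical names.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical lock-free counter: same retry loop, backed by AtomicInteger.
class RealCASCounter {
    private final AtomicInteger count = new AtomicInteger();

    void increase() {
        int oldValue, newValue;
        do {
            oldValue = count.get();    // read the current value once
            newValue = oldValue + 1;   // compute the new value from that read
        } while (!count.compareAndSet(oldValue, newValue)); // retry on failure
    }

    // Starts two threads that each call increase() perThread times.
    static int runTwoThreads(int perThread) {
        RealCASCounter counter = new RealCASCounter();
        Runnable task = () -> {
            for (int i = 0; i < perThread; i++) {
                counter.increase();
            }
        };
        Thread t1 = new Thread(task);
        Thread t2 = new Thread(task);
        t1.start();
        t2.start();
        try {
            t1.join();
            t2.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return counter.count.get();
    }

    public static void main(String[] args) {
        System.out.println(runTwoThreads(10_000)); // 20000
    }
}
```

Outside of an exercise, `count.incrementAndGet()` performs the same atomic increment without the hand-written loop.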

Upvotes: 2
