Jordi Reinsma
Jordi Reinsma

Reputation: 183

Passing a WaitGroup to a function changes behavior, why?

I have 3 merge sort implementations:

MergeSort: simple one without concurrency;

MergeSortSmart: with concurrency limited by the capacity of a buffered channel. If the buffer is full, it calls the simple implementation;

MergeSortSmartBug: same strategy as the previous one, but with a small "refactor": a pointer to the WaitGroup is passed to a helper function to reduce code duplication.

The first two work as expected, but the third one returns an empty slice instead of the sorted input. I couldn't understand what happened and couldn't find any answers either.

Here is the playground link for the code: https://play.golang.org/p/DU1ypbanpVi

package main

import (
    "fmt"
    "math/rand"
    "runtime"
    "sync"
)

// pass is the empty token type sent through semaphore; it carries no data.
type pass struct{}

// semaphore is a buffered channel whose capacity is runtime.NumCPU(). The
// concurrent sort variants attempt a non-blocking send on it before starting
// a goroutine, and that goroutine receives from it when its work is done.
var semaphore = make(chan pass, runtime.NumCPU())

// main fills a 16-element slice with pseudo-random ints in [0, 1000) from a
// fixed seed, prints it, and then prints the result of each sort variant.
func main() {
    // A locally seeded generator replaces the deprecated global rand.Seed;
    // the same seed produces the same sequence of values.
    rng := rand.New(rand.NewSource(10))

    s := make([]int, 16)
    for i := range s {
        s[i] = rng.Intn(1000)
    }

    fmt.Println(s)
    fmt.Println(MergeSort(s))
    fmt.Println(MergeSortSmart(s))
    fmt.Println(MergeSortSmartBug(s))
}

// merge interleaves l and r into a newly allocated slice, always taking the
// smaller of the two current heads and preferring l on ties. When both inputs
// are sorted ascending the result is sorted ascending too. Neither input is
// modified.
func merge(l, r []int) []int {
    out := make([]int, 0, len(l)+len(r))

    i, j := 0, 0
    for i < len(l) && j < len(r) {
        if r[j] < l[i] {
            out = append(out, r[j])
            j++
            continue
        }
        out = append(out, l[i])
        i++
    }

    // At most one of the two remainders is non-empty at this point.
    out = append(out, l[i:]...)
    return append(out, r[j:]...)
}

// MergeSort is the sequential top-down merge sort. Slices of length 0 or 1
// are returned as-is; longer slices are split at the midpoint, each half is
// sorted recursively, and the two results are combined with merge.
func MergeSort(s []int) []int {
    if len(s) < 2 {
        return s
    }

    mid := len(s) / 2
    left, right := s[:mid], s[mid:]

    return merge(MergeSort(left), MergeSort(right))
}

// MergeSortSmart is a merge sort that sorts each half concurrently when the
// semaphore has a free slot. For each half it tries a non-blocking send on
// semaphore: on success the half is sorted recursively in a new goroutine,
// which frees the slot when it finishes; otherwise the half is sorted on the
// calling goroutine with the sequential MergeSort.
func MergeSortSmart(s []int) []int {
    if len(s) < 2 {
        return s
    }

    mid := len(s) / 2
    head, tail := s[:mid], s[mid:]

    var left, right []int
    var pending sync.WaitGroup
    pending.Add(2)

    // First half.
    select {
    case semaphore <- pass{}:
        go func() {
            left = MergeSortSmart(head)
            <-semaphore
            pending.Done()
        }()
    default:
        left = MergeSort(head)
        pending.Done()
    }

    // Second half.
    select {
    case semaphore <- pass{}:
        go func() {
            right = MergeSortSmart(tail)
            <-semaphore
            pending.Done()
        }()
    default:
        right = MergeSort(tail)
        pending.Done()
    }

    // Both halves must be stored before they are read by merge.
    pending.Wait()
    return merge(left, right)
}

// MergeSortSmartBug is the refactored variant of MergeSortSmart that hands
// each half to mergeSmart. It does not sort correctly: mergeSmart returns its
// result by value, and on the goroutine path it returns before the goroutine
// has stored anything, so l and r can still be nil when merge is called.
func MergeSortSmartBug(s []int) []int {
    if len(s) <= 1 {
        return s
    }

    n := len(s) / 2

    var wg sync.WaitGroup
    wg.Add(2)

    // l and r receive copies of mergeSmart's local tmp as it was at return
    // time; a later write to tmp inside the goroutine never reaches them.
    l := mergeSmart(s[:n], &wg)
    r := mergeSmart(s[n:], &wg)

    // Waiting only guarantees the goroutines have finished; it does not
    // update the values already copied into l and r.
    wg.Wait()
    return merge(l, r)
}

// mergeSmart sorts s and calls wg.Done once the work has finished. If a
// non-blocking send on semaphore succeeds, the sort runs in a new goroutine;
// otherwise it runs synchronously with MergeSort.
//
// Only the synchronous path returns the sorted slice. On the goroutine path
// the function reaches `return tmp` without waiting, so the caller normally
// gets the nil zero value, and the read of tmp here races with the write in
// the goroutine.
func mergeSmart(s []int, wg *sync.WaitGroup) []int {
    var tmp []int
    select {
    case semaphore <- pass{}:
        go func() {
            // This assignment targets the local tmp; the caller has its own
            // copy of whatever was returned and never sees this value.
            tmp = MergeSortSmartBug(s)
            <-semaphore
            wg.Done()
        }()
    default:
        tmp = MergeSort(s)
        wg.Done()
    }
    return tmp
}

Why does the Bug version return an empty slice? How can I refactor the Smart version without doing two selects one after the other?

Sorry that I couldn't reproduce this behavior in a smaller example.

Upvotes: 1

Views: 393

Answers (4)

Jordi Reinsma
Jordi Reinsma

Reputation: 183

I implemented both suggestions (passing slice by reference and using channels) and the (working!) result is here: https://play.golang.org/p/DcDC_-NjjAH

package main

import (
    "fmt"
    "math/rand"
    "runtime"
    "sync"
)

// pass is the empty token type sent through semaphore; it carries no data.
type pass struct{}

// semaphore is a buffered channel whose capacity is runtime.NumCPU(). The
// concurrent sort variants attempt a non-blocking send on it before starting
// a goroutine, and that goroutine receives from it when its work is done.
var semaphore = make(chan pass, runtime.NumCPU())

// main fills a 16-element slice with pseudo-random ints in [0, 1000) from a
// fixed seed, prints it, and then prints the result of each sort variant.
func main() {
    // A locally seeded generator replaces the deprecated global rand.Seed;
    // the same seed produces the same sequence of values.
    rng := rand.New(rand.NewSource(10))

    s := make([]int, 16)
    for i := range s {
        s[i] = rng.Intn(1000)
    }

    fmt.Println(s)
    fmt.Println(MergeSort(s))
    fmt.Println(MergeSortSmart(s))
    fmt.Println(MergeSortSmartPointer(s))
    fmt.Println(MergeSortSmartChan(s))
}

// merge builds a newly allocated slice from l and r by repeatedly moving the
// smaller head element across, taking from l when the heads are equal. Once
// either input is exhausted the rest of the other is appended. For inputs
// sorted ascending the result is sorted ascending.
func merge(l, r []int) []int {
    merged := make([]int, 0, len(l)+len(r))
    for {
        switch {
        case len(l) == 0:
            return append(merged, r...)
        case len(r) == 0:
            return append(merged, l...)
        case l[0] <= r[0]:
            merged, l = append(merged, l[0]), l[1:]
        default:
            merged, r = append(merged, r[0]), r[1:]
        }
    }
}

// MergeSort sorts s with a plain recursive merge sort and no concurrency.
// A slice with fewer than two elements is returned unchanged; otherwise the
// two halves are sorted independently and combined by merge.
func MergeSort(s []int) []int {
    if len(s) < 2 {
        return s
    }

    half := len(s) / 2
    sortedLeft := MergeSort(s[:half])
    sortedRight := MergeSort(s[half:])
    return merge(sortedLeft, sortedRight)
}

// MergeSortSmart is a merge sort that sorts a half in its own goroutine
// whenever the semaphore has a free slot, and falls back to the sequential
// MergeSort on the calling goroutine when it does not.
func MergeSortSmart(s []int) []int {
    if len(s) < 2 {
        return s
    }

    mid := len(s) / 2

    var left, right []int
    var wg sync.WaitGroup
    wg.Add(2)

    // sortInto stores the sorted form of part in *dst and marks one unit of
    // work done on wg, either from a new goroutine or synchronously.
    sortInto := func(dst *[]int, part []int) {
        select {
        case semaphore <- pass{}:
            go func() {
                *dst = MergeSortSmart(part)
                <-semaphore // free the slot before signalling completion
                wg.Done()
            }()
        default:
            *dst = MergeSort(part)
            wg.Done()
        }
    }

    sortInto(&left, s[:mid])
    sortInto(&right, s[mid:])

    wg.Wait()
    return merge(left, right)
}

// MergeSortSmartPointer is the merge sort variant that lets
// mergeSmartPointer write each sorted half through a pointer. It waits on
// the WaitGroup until both halves have been stored, then merges them.
func MergeSortSmartPointer(s []int) []int {
    if len(s) < 2 {
        return s
    }

    var wg sync.WaitGroup
    var left, right []int
    mid := len(s) / 2

    wg.Add(2)
    mergeSmartPointer(&left, s[:mid], &wg)
    mergeSmartPointer(&right, s[mid:], &wg)
    wg.Wait()

    return merge(left, right)
}

// mergeSmartPointer sorts s, stores the result in *tmp, and then calls
// wg.Done. When a non-blocking send on semaphore succeeds the work is done
// in a new goroutine that frees the slot afterwards; otherwise the slice is
// sorted on the calling goroutine with MergeSort.
func mergeSmartPointer(tmp *[]int, s []int, wg *sync.WaitGroup) {
    select {
    case semaphore <- pass{}:
        // Slot acquired; continue below.
    default:
        *tmp = MergeSort(s)
        wg.Done()
        return
    }

    go func() {
        *tmp = MergeSortSmartPointer(s)
        <-semaphore
        wg.Done()
    }()
}

// MergeSortSmartChan is the merge sort variant that collects each sorted
// half from a channel filled by mergeSmartChan.
//
// Each channel has capacity 1, so mergeSmartChan can deliver its result
// before this function starts receiving. That allows it to be called
// directly rather than with `go`: the sequential fallback then runs on the
// calling goroutine, and the only goroutines started are the ones that hold
// a semaphore slot.
func MergeSortSmartChan(s []int) []int {
    if len(s) <= 1 {
        return s
    }
    n := len(s) / 2

    lchan := make(chan []int, 1)
    rchan := make(chan []int, 1)

    mergeSmartChan(s[:n], lchan)
    mergeSmartChan(s[n:], rchan)

    // Each channel receives exactly one value.
    l := <-lchan
    r := <-rchan

    return merge(l, r)
}

// mergeSmartChan sorts s and sends the result on c. When a non-blocking send
// on semaphore succeeds, the sort and the send happen in a new goroutine,
// which frees the slot after the send completes; otherwise s is sorted with
// MergeSort and sent from the calling goroutine.
func mergeSmartChan(s []int, c chan []int) {
    select {
    case semaphore <- pass{}:
        // Slot acquired; continue below.
    default:
        c <- MergeSort(s)
        return
    }

    go func() {
        c <- MergeSortSmartChan(s)
        <-semaphore
    }()
}

I understood 100% what I was doing wrong, thanks! And for future reference, here's the benchmark of sorting a slice of 100,000 elements:

$ go test -bench=.
goos: linux
goarch: amd64
cpu: Intel(R) Core(TM) i5-9300H CPU @ 2.40GHz
BenchmarkMergeSort-8                      97      12230309 ns/op
BenchmarkMergeSortSmart-8                181       7209844 ns/op
BenchmarkMergeSortSmartPointer-8         163       7483136 ns/op
BenchmarkMergeSortSmartChan-8            156       8149585 ns/op

Upvotes: 0

Bryan Austin
Bryan Austin

Reputation: 497

The problem is not with the WaitGroup itself; it's with your concurrency handling. Your mergeSmart function launches a goroutine and returns the tmp variable without waiting for the goroutine to finish. You might want to try a pattern more like this:

leftchan := make(chan []int)
rightchan := make(chan []int)
go mergeSmart(s[:n], leftchan)
go mergeSmart(s[n:], rightchan)
l := <-leftchan
r := <-rightchan

Or you can use a single channel if order doesn't matter.

Upvotes: 2

Fulldump
Fulldump

Reputation: 883

Look at the mergeSmart function. When the select enters the first case, the goroutine is launched and the function immediately returns tmp (which is still an empty slice). In that case there is no way to get the right value. (See advanced debugging prints here https://play.golang.org/p/IedaY3muso2) Maybe pass preallocated slices by reference?

Upvotes: 1

hobbs
hobbs

Reputation: 239801

mergeSmart doesn't wait on the wg, so it returns a tmp that hasn't received a value yet. You could probably repair it by passing a reference to the destination slice in to the function, instead of returning a slice.

Upvotes: 1

Related Questions