Let's say I have a function IsAPrimaryColour() which works by calling three other functions IsRed(), IsGreen() and IsBlue(). Since the three functions are quite independent of one another, they can run concurrently. The return conditions are:
The thing I'm struggling with is how to exit the function as soon as any of the three functions returns true, but also to wait for all three to finish if they all return false. If I use a sync.WaitGroup object, I will need to wait for all 3 goroutines to finish before I can return from the calling function.
Therefore, I'm using a loop counter to keep track of how many times I have received a message on a channel, and exiting the function once I have received all 3 messages.
https://play.golang.org/p/kNfqWVq4Wix
package main
import (
"errors"
"fmt"
"time"
)
func main() {
x := "something"
result, err := IsAPrimaryColour(x)
if err != nil {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Printf("Result: %v\n", result)
}
}
func IsAPrimaryColour(value interface{}) (bool, error) {
found := make(chan bool, 3)
errors := make(chan error, 3)
defer close(found)
defer close(errors)
var nsec int64 = time.Now().UnixNano()
//call the first function, return the result on the 'found' channel and any errors on the 'errors' channel
go func() {
result, err := IsRed(value)
if err != nil {
errors <- err
} else {
found <- result
}
fmt.Printf("IsRed done in %f nanoseconds \n", float64(time.Now().UnixNano()-nsec))
}()
//call the second function, return the result on the 'found' channel and any errors on the 'errors' channel
go func() {
result, err := IsGreen(value)
if err != nil {
errors <- err
} else {
found <- result
}
fmt.Printf("IsGreen done in %f nanoseconds \n", float64(time.Now().UnixNano()-nsec))
}()
//call the third function, return the result on the 'found' channel and any errors on the 'errors' channel
go func() {
result, err := IsBlue(value)
if err != nil {
errors <- err
} else {
found <- result
}
fmt.Printf("IsBlue done in %f nanoseconds \n", float64(time.Now().UnixNano()-nsec))
}()
//loop counter which will be incremented every time we read a value from the 'found' channel
var counter int
for {
select {
case result := <-found:
counter++
fmt.Printf("received a value on the results channel after %f nanoseconds. Value of counter is %d\n", float64(time.Now().UnixNano()-nsec), counter)
if result {
fmt.Printf("some goroutine returned true\n")
return true, nil
}
case err := <-errors:
if err != nil {
fmt.Printf("some goroutine returned an error\n")
return false, err
}
default:
}
//check if we have received all 3 messages on the 'found' channel. If so, all 3 functions must have returned false and we can thus return false also
if counter == 3 {
fmt.Printf("all goroutines have finished and none of them returned true\n")
return false, nil
}
}
}
func IsRed(value interface{}) (bool, error) {
return false, nil
}
func IsGreen(value interface{}) (bool, error) {
time.Sleep(time.Millisecond * 100) //change this to a value greater than 200 to make this function take longer than IsBlue()
return true, nil
}
func IsBlue(value interface{}) (bool, error) {
time.Sleep(time.Millisecond * 200)
return false, errors.New("something went wrong")
}
Although this works well enough, I wonder if I'm not overlooking some language feature to do this in a better way?
Upvotes: 2
Views: 2631
Reputation: 5197
errgroup.WithContext can help simplify the concurrency here.
You want to stop all of the goroutines if an error occurs, or if a result is found. If you can express “a result is found” as a distinguished error (along the lines of io.EOF), then you can use errgroup's built-in “cancel on first error” behavior to shut down the whole group:
func IsAPrimaryColour(ctx context.Context, value interface{}) (bool, error) {
var nsec int64 = time.Now().UnixNano()
errFound := errors.New("result found")
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
result, err := IsRed(ctx, value)
if result {
err = errFound
}
fmt.Printf("IsRed done in %f nanoseconds \n", float64(time.Now().UnixNano()-nsec))
return err
})
…
err := g.Wait()
if err == errFound {
fmt.Printf("some goroutine returned errFound\n")
return true, nil
}
if err != nil {
fmt.Printf("some goroutine returned an error\n")
return false, err
}
fmt.Printf("all goroutines have finished and none of them returned true\n")
return false, nil
}
(https://play.golang.org/p/MVeeBpDv4Mn)
Upvotes: 4
Reputation: 79734
The idiomatic way to handle multiple concurrent function calls, and cancel any outstanding after a condition, is with the use of a context value. Something like this:
func operation1(ctx context.Context) bool { ... }
func operation2(ctx context.Context) bool { ... }
func operation3(ctx context.Context) bool { ... }
func atLeastOneSuccess() bool {
ctx, cancel := context.WithCancel(context.Background())
defer cancel() // Ensure any functions still running get the signal to stop
results := make(chan bool, 3) // A channel to send results
go func() {
results <- operation1(ctx)
}()
go func() {
results <- operation2(ctx)
}()
go func() {
results <- operation3(ctx)
}()
for i := 0; i < 3; i++ {
result := <-results
if result {
// One of the operations returned success, so we'll return that
// and let the deferred call to cancel() tell any outstanding
// functions to abort.
return true
}
}
// We've looped through all return values, and they were all false
return false
}
Of course this assumes that each of the operationN functions actually honors a canceled context. This answer discusses how to do that.
Upvotes: 1
Reputation:
some remarks,
So once you get rid of all the fat, the code looks like
func IsAPrimaryColour(value interface{}) (bool, error) {
fns := []func(interface{}) (bool, error){IsRed, IsGreen, IsBlue}
found := make(chan bool, len(fns))
errors := make(chan error, len(fns))
for i := 0; i < len(fns); i++ {
fn := fns[i]
go func() {
result, err := fn(value)
if err != nil {
errors <- err
return
}
found <- result
}()
}
for i := 0; i < len(fns); i++ {
select {
case result := <-found:
if result {
return true, nil
}
case err := <-errors:
if err != nil {
return false, err
}
}
}
return false, nil
}
func main() {
now := time.Now()
x := "something"
result, err := IsAPrimaryColour(x)
if err != nil {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Printf("Result: %v\n", result)
}
fmt.Println("it took", time.Since(now))
}
https://play.golang.org/p/bARHS6c6m1c
Upvotes: 1
Reputation: 46562
You don't have to block the main goroutine on the Wait, you could block something else, for example:
doneCh := make(chan struct{})
go func() {
wg.Wait()
close(doneCh)
}()
Then you can wait on doneCh in your select to see if all the routines have finished.
Upvotes: 0