temporarya

Reputation: 755

How to implement a pipeline with goroutines?

I need some help understanding how to use a pipeline to transfer data from one goroutine to another.

I read the Go blog post on pipelines; I understood it, but couldn't fully put it into action, so I thought I'd seek help from the community.

Now, I have come up with this ugly code (Playground):

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    wg := sync.WaitGroup{}
    ch := make(chan int)
    for a := 0; a < 3; a++ {
        wg.Add(1)
        go func1(int(3-a), ch, &wg)
    }
    go func() {
        wg.Wait()
        close(ch)
    }()
    wg2 := sync.WaitGroup{}
    ch2 := make(chan string)
    for val := range ch {
        fmt.Println(val)
        wg2.Add(1)
        go func2(val, ch2, &wg2)
    }
    go func() {
        wg2.Wait()
        close(ch2)
    }()
    for val := range ch2 {
        fmt.Println(val)
    }
}

func func1(seconds int, ch chan<- int, wg *sync.WaitGroup) {
    defer wg.Done()
    time.Sleep(time.Duration(seconds) * time.Second)
    ch <- seconds
}

func func2(seconds int, ch chan<- string, wg *sync.WaitGroup) {
    defer wg.Done()
    ch <- "hello"
}

Problem

I want to do this the proper way, using pipelines or whatever the idiomatic approach is.

Also, the pipeline shown in the blog post isn't built around goroutines the way my code is, so I haven't been able to adapt it myself.

In real life, func1 and func2 are functions that fetch resources from the web, which is why each is launched in its own goroutine.

Thanks.
Temporarya
( A golang noobie )

P.S. Real life examples and usage of pipeline using goroutines would be of great help too.

Upvotes: 0

Views: 1000

Answers (2)

Jerry An

Reputation: 1432

This article covers the pipeline pattern with a port scanner example; I recommend taking a look at it.

A port scanner is designed to probe a server or host for open ports.

[Image: pipeline diagram of the port scanner]

The image above shows the whole pipeline of the port scanner. Let’s explain each related function one by one in the next section.

  • Func init
  • Func parsePortsToScan
  • struct scanOp
  • Func gen
  • Func scan
  • Func filter
  • Func store
  • Func main

The init function defines the command-line flags. The ports variable is a string of the ports to scan, separated by a dash. The outFile variable is the file the results are written to.

var ports string
var outFile string

func init() {
    flag.StringVar(&ports, "ports", "80", "Port(s) (e.g. 80, 22-100).")
    flag.StringVar(&outFile, "outfile", "scans.csv",
        "Destination of scan results (defaults to scans.csv)")
}

The main function is responsible for executing the pipeline of functions. It reads the ports string and the outFile destination from the command-line flags.

func main() {
    flag.Parse()

    portsToScan, err := parsePortsToScan(ports)
    if err != nil {
        fmt.Printf("Failed to parse ports to scan: %s\n", err)
        os.Exit(1)
    }

    dest, err := os.Create(outFile)
    if err != nil {
        fmt.Printf("Failed to create scan results destination: %s\n", err)
        os.Exit(2)
    }

    // pipeline
    // scanChan := store(dest, filter(scan(gen(portsToScan...))))

    // broken up for explainability
    var scanChan <-chan scanOp
    scanChan = gen(portsToScan...)
    scanChan = scan(scanChan)
    scanChan = filter(scanChan)
    scanChan = store(dest, scanChan)

    for s := range scanChan {
        if !s.open && s.scanErr != fmt.Sprintf("dial tcp 127.0.0.1:%d: connect: connection refused", s.port) {
            fmt.Println(s.scanErr)
        }
    }
}

The parsePortsToScan function parses the ports to scan from the command line argument. If the argument is invalid, an error is returned. If the argument is valid, a slice of ints is returned.

func parsePortsToScan(portsFlag string) ([]int, error) {
    p, err := strconv.Atoi(portsFlag)
    if err == nil {
        return []int{p}, nil
    }

    ports := strings.Split(portsFlag, "-")
    if len(ports) != 2 {
        return nil, errors.New("unable to determine port(s) to scan")
    }

    minPort, err := strconv.Atoi(ports[0])
    if err != nil {
        return nil, fmt.Errorf("failed to convert %s to a valid port number", ports[0])
    }

    maxPort, err := strconv.Atoi(ports[1])
    if err != nil {
        return nil, fmt.Errorf("failed to convert %s to a valid port number", ports[1])
    }

    if minPort <= 0 || maxPort <= 0 {
        return nil, fmt.Errorf("port numbers must be greater than 0")
    }

    var results []int
    for p := minPort; p <= maxPort; p++ {
        results = append(results, p)
    }
    return results, nil
}
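As a quick sanity check of the parsing behavior, here is a standalone sketch; parseRange is a trimmed copy of parsePortsToScan, renamed so the snippet is self-contained:

```go
package main

import (
	"errors"
	"fmt"
	"strconv"
	"strings"
)

// parseRange mirrors parsePortsToScan: it accepts either a single
// port ("80") or a dashed range ("22-25").
func parseRange(s string) ([]int, error) {
	if p, err := strconv.Atoi(s); err == nil {
		return []int{p}, nil
	}
	parts := strings.Split(s, "-")
	if len(parts) != 2 {
		return nil, errors.New("unable to determine port(s) to scan")
	}
	minPort, err1 := strconv.Atoi(parts[0])
	maxPort, err2 := strconv.Atoi(parts[1])
	if err1 != nil || err2 != nil || minPort <= 0 || maxPort <= 0 {
		return nil, errors.New("invalid port number")
	}
	var out []int
	for p := minPort; p <= maxPort; p++ {
		out = append(out, p)
	}
	return out, nil
}

func main() {
	fmt.Println(parseRange("80"))    // [80] <nil>
	fmt.Println(parseRange("22-25")) // [22 23 24 25] <nil>
	_, err := parseRange("abc")
	fmt.Println(err) // unable to determine port(s) to scan
}
```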

scanOp represents a single port scan operation and its results. open is a boolean indicating whether the port is open, scanErr is an error message if the scan fails, and scanDuration is the time the scan took.

To output results to a CSV file, scanOp provides two helper methods for the CSV writer: csvHeaders returns the column headers as a slice of strings, and asSlice returns scanOp's field values as a slice of strings.

type scanOp struct {
    port         int
    open         bool
    scanErr      string
    scanDuration time.Duration
}

func (so scanOp) csvHeaders() []string {
    return []string{"port", "open", "scanError", "scanDuration"}
}

func (so scanOp) asSlice() []string {
    return []string{
        strconv.FormatInt(int64(so.port), 10),
        strconv.FormatBool(so.open),
        so.scanErr,
        so.scanDuration.String(),
    }
}

The gen function is a generator that returns a buffered channel of scanOp values built from a slice of int ports. It is the first stage of the pipeline of functions executed in sequence.

func gen(ports ...int) <-chan scanOp {
    out := make(chan scanOp, len(ports))
    go func() {
        defer close(out)
        for _, p := range ports {
            out <- scanOp{port: p}
        }
    }()
    return out
}

The scan function is responsible for performing the actual port scan. It takes a buffered channel of scanOps and returns an unbuffered channel of scanOps.

func scan(in <-chan scanOp) <-chan scanOp {
    out := make(chan scanOp)
    go func() {
        defer close(out)
        for scan := range in {
            address := fmt.Sprintf("127.0.0.1:%d", scan.port)
            start := time.Now()
            conn, err := net.Dial("tcp", address)
            scan.scanDuration = time.Since(start)
            if err != nil {
                scan.scanErr = err.Error()
            } else {
                conn.Close()
                scan.open = true
            }
            out <- scan
        }
    }()
    return out
}

The filter function passes through only the scanOps whose port is open.

func filter(in <-chan scanOp) <-chan scanOp {
    out := make(chan scanOp)
    go func() {
        defer close(out)
        for scan := range in {
            if scan.open {
                out <- scan
            }
        }
    }()
    return out
}

The store function is responsible for storing scanOps in a CSV file. It is the last function in the pipeline.

func store(file io.Writer, in <-chan scanOp) <-chan scanOp {
    csvWriter := csv.NewWriter(file)
    out := make(chan scanOp)
    go func() {
        defer csvWriter.Flush()
        defer close(out)
        var headerWritten bool
        for scan := range in {
            if !headerWritten {
                headers := scan.csvHeaders()
                if err := csvWriter.Write(headers); err != nil {
                    fmt.Println(err)
                    break
                }
                headerWritten = true
            }
            values := scan.asSlice()
            if err := csvWriter.Write(values); err != nil {
                fmt.Println(err)
                break
            }
        }
    }()

    return out
}

Channels can be used to connect goroutines together so that the output of one is the input to another. This is especially helpful when your pipeline has many stages and you want to compose them.
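That composition can be shown with a minimal, self-contained sketch (hypothetical stage names, in the style of the Go blog's pipelines post): each stage owns its output channel and closes it when done, so the stages chain cleanly.

```go
package main

import "fmt"

// gen emits the given ints and closes its output when done.
func gen(nums ...int) <-chan int {
	out := make(chan int)
	go func() {
		defer close(out)
		for _, n := range nums {
			out <- n
		}
	}()
	return out
}

// sq squares each value from in; when in is closed, this stage
// finishes and closes its own output in turn.
func sq(in <-chan int) <-chan int {
	out := make(chan int)
	go func() {
		defer close(out)
		for n := range in {
			out <- n * n
		}
	}()
	return out
}

func main() {
	// Stages compose like function calls: gen feeds sq feeds main.
	for v := range sq(gen(2, 3, 4)) {
		fmt.Println(v) // 4, 9, 16
	}
}
```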

Upvotes: 0

David Maze

Reputation: 158686

The key pattern of that pipelines post is that you can view the contents of a channel as a stream of data, and write a set of cooperating goroutines that build up a data-processing stream graph. This can be a way to get some concurrency into a data-oriented application.

In terms of design, you may also find it more helpful to build up blocks that aren't tied to the goroutine structure, and wrap them in channels. This makes it much easier to test the lower-level code, and if you change your mind about running things in a goroutine or not, it's easier to add or remove the wrapper.

So in your example I'd start by refactoring the lowest-level tasks out into their own (synchronous) functions:

func fetch(ms int) int {
    time.Sleep(time.Duration(ms) * time.Millisecond)
    return ms
}

func report(ms int) string {
    return fmt.Sprintf("Hello after %d ms", ms)
}

Since the second half of your example is fairly synchronous, it's easy to adapt to the pipeline pattern. We write a function that consumes all of its input stream and produces a complete output stream, closing it when it's done.

func reportAll(mss <-chan int, out chan<- string) {
    for ms := range mss {
        out <- report(ms)
    }
    close(out)
}

The function that calls the asynchronous code is a little trickier. In the main loop of the function, every time you read a value, you need to launch a goroutine to process it. Then, after you've read everything out of the input channel, you need to wait for all of those goroutines to finish before closing the output channel. You can use a small anonymous function here to help.

func fetchAll(mss <-chan int, out chan<- int) {
    var wg sync.WaitGroup
    for ms := range mss {
        wg.Add(1)
        go func(ms int) {
            out <- fetch(ms)
            wg.Done()
        }(ms)
    }
    wg.Wait()
    close(out)
}

It's also helpful here (because channel writes are blocking) to write another function to seed the input values.

func produceInputs(mss chan<- int) {
    for ms := 1000; ms > 0; ms -= 300 {
        mss <- ms
    }
    close(mss)
}

Now your main function needs to create the channels between these and run the final consumer.

// main is the entry point to the program.
//
//                   mss        fetched       results
//     produceInputs --> fetchAll --> reportAll --> main
func main() {
    mss := make(chan int)
    fetched := make(chan int)
    results := make(chan string)

    go produceInputs(mss)
    go fetchAll(mss, fetched)
    go reportAll(fetched, results)

    for val := range results {
        fmt.Println(val)
    }
}

https://play.golang.org/p/V9Z7ECUVIJL is a complete example.

I've avoided manually passing sync.WaitGroups around here (and tend to avoid it in general: a function shouldn't need a WaitGroup unless it is itself launching things as the top level of a goroutine, so keeping the WaitGroup management inside the function that spawns the goroutines, as in my fetchAll function above, makes the code more modular). How do I know all of my goroutines have finished? We can trace through:

  • If I've reached the end of main, the results channel is closed.
  • The results channel is the output channel of reportAll; if it closed, then that function reached the end of its execution; and if that happened then the fetched channel is closed.
  • The fetched channel is the output channel of fetchAll; ...

Another way to look at this is that as soon as the pipeline's source (produceInputs) closes its output channel and finishes, that "I'm done" signal flows down the pipeline and causes the downstream steps to close their output channels and finish too.
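That shutdown cascade can be demonstrated with a small self-contained sketch (hypothetical stage names): the source closes its channel, the middle stage's range loop exits and closes its own output, and the final consumer's range loop terminates.

```go
package main

import "fmt"

// runPipeline wires source -> middle -> collector and returns what the
// final consumer received. Each stage closing its output channel is
// what lets the downstream range loops, and hence the whole pipeline, end.
func runPipeline() []int {
	src := make(chan int)
	mid := make(chan int)

	go func() { // source stage
		for i := 1; i <= 3; i++ {
			src <- i
		}
		close(src) // "I'm done": lets the middle stage's range exit
	}()
	go func() { // middle stage
		for v := range src {
			mid <- v * 10
		}
		close(mid) // propagate the shutdown signal downstream
	}()

	var got []int
	for v := range mid { // final consumer: exits once mid is closed
		got = append(got, v)
	}
	return got
}

func main() {
	fmt.Println(runPipeline()) // [10 20 30]
}
```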

The blog post mentions a separate explicit close channel. I haven't gone into that here at all. Since it was written, though, the standard library gained the context package, which is now the standard idiom for managing those. You'd need to use a select statement in the body of the main loop, which makes the handling a little more complicated. This might look like:

func reportAllCtx(ctx context.Context, mss <-chan int, out chan<- string) {
    defer close(out)
    for {
        select {
        case <-ctx.Done():
            return // canceled; note a "break" here would only exit the select
        case ms, ok := <-mss:
            if !ok {
                return // input closed; the deferred close(out) still runs
            }
            out <- report(ms)
        }
    }
}

Upvotes: 3
