gorilla_glue

Reputation: 355

Transforming `PCollection` with many elements into a single element

I am trying to convert a PCollection that has many elements into a PCollection that has a single element. Basically, I want to go from:

[1,2,3,4,5,6]

to:

[[1,2,3,4,5,6]]

so that I can work with the entire PCollection in a DoFn.

I've tried `CombineGlobally(lambda x: x)`, but the elements only get combined into arrays a portion at a time, giving me the following result:

[1,2,3,4,5,6] -> [[1,2],[3,4],[5,6]]

Or something to that effect.
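From what I can tell, Beam applies a plain callable per batch of inputs, and then again to the partial results, which would explain the nesting. Roughly like this pure-Python sketch (no Beam needed; the batch split is made up for illustration):

```python
# Rough hand simulation of how CombineGlobally seems to evaluate a plain
# callable: the function sees a batch of input values, and later a batch
# of the partial results. An identity function therefore nests batches.
def identity_combine(values):
    return list(values)

batches = [[1, 2], [3, 4], [5, 6]]  # made-up batch split
partials = [identity_combine(b) for b in batches]
final = identity_combine(partials)
print(final)  # [[1, 2], [3, 4], [5, 6]] -- batches, not one flat list
```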

This is the relevant portion of the script I'm trying to run:

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline

raw_input = range(1024)

def run_test():
    with TestPipeline() as test_pl:
        input = test_pl | "Create" >> beam.Create(raw_input)

        def combine(x):
            print(x)
            return x

        (
            input
            | "Global aggregation" >> beam.CombineGlobally(combine)
        )
        # The with-block runs the pipeline on exit; no explicit run() needed.

run_test()

Upvotes: 2

Views: 1916

Answers (2)

robertwb

Reputation: 5104

You can also accomplish what you want with side inputs, e.g.

with beam.Pipeline() as p:
    pcoll = ...
    (p
     # Create a PCollection with a single element.
     | beam.Create([None])
     # This will process the singleton exactly once, with the
     # entirety of pcoll passed in as a second argument as a list.
     | beam.Map(
        lambda _, pcoll_as_side: ...consume pcoll_as_side here...,
        pcoll_as_side=beam.pvalue.AsList(pcoll)))

Upvotes: 2

gorilla_glue

Reputation: 355

I figured out a pretty painless way to do this, which I missed in the docs:

The more general way to combine elements, and the most flexible, is with a class that inherits from CombineFn.

CombineFn.create_accumulator(): This creates an empty accumulator. For example, an empty accumulator for a sum would be 0, while an empty accumulator for a product (multiplication) would be 1.

CombineFn.add_input(): Called once per element. Takes an accumulator and an input element, combines them and returns the updated accumulator.

CombineFn.merge_accumulators(): Multiple accumulators could be processed in parallel, so this function helps merge them into a single accumulator.

CombineFn.extract_output(): It allows additional calculations to be performed before extracting a result.

I suppose supplying the "vanilla" CombineGlobally with a lambda that simply returns its argument wouldn't do what I initially expected. That merging behavior has to be specified by me (although I still think it's weird this isn't built into the API).

You can find more about subclassing CombineFn here, which I found very helpful:

A CombineFn specifies how multiple values in all or part of a PCollection can be merged into a single value—essentially providing the same kind of information as the arguments to the Python “reduce” builtin (except for the input argument, which is an instance of CombineFnProcessContext). The combining process proceeds as follows:

  1. Input values are partitioned into one or more batches.
  2. For each batch, the create_accumulator method is invoked to create a fresh initial “accumulator” value representing the combination of zero values.
  3. For each input value in the batch, the add_input method is invoked to combine more values with the accumulator for that batch.
  4. The merge_accumulators method is invoked to combine accumulators from separate batches into a single combined output accumulator value, once all of the accumulators have had all the input values in their batches added to them. This operation is invoked repeatedly, until there is only one accumulator value left.
  5. The extract_output operation is invoked on the final accumulator to get the output value. Note: If this CombineFn is used with a transform that has defaults, apply will be called with an empty list at expansion time to get the default value.
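Those steps can be walked through by hand. Here is a plain Python stand-in (deliberately not inheriting beam.CombineFn, so it runs without a Beam runner; the two-batch split is made up for illustration):

```python
class ToListFn:
    """Mirrors the CombineFn method names for illustration only."""
    def create_accumulator(self):
        return []                                # step 2: fresh accumulator
    def add_input(self, acc, element):
        acc.append(element)                      # step 3: fold one value in
        return acc
    def merge_accumulators(self, accs):
        return [x for acc in accs for x in acc]  # step 4: merge batches
    def extract_output(self, acc):
        return acc                               # step 5: final output

fn = ToListFn()
# Step 1: pretend the runner split the input into two batches.
acc_a = fn.create_accumulator()
for v in [1, 2, 3]:
    acc_a = fn.add_input(acc_a, v)
acc_b = fn.create_accumulator()
for v in [4, 5, 6]:
    acc_b = fn.add_input(acc_b, v)
result = fn.extract_output(fn.merge_accumulators([acc_a, acc_b]))
print(result)  # [1, 2, 3, 4, 5, 6]
```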

So, by subclassing CombineFn, I wrote this simple implementation, Aggregated, that does exactly what I want:

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline

raw_input = range(1024)


class Aggregated(beam.CombineFn):
    def create_accumulator(self):
        return []
    
    def add_input(self, accumulator, element):
        accumulator.append(element)
        return accumulator
    
    def merge_accumulators(self, accumulators):
        merged = []
        for a in accumulators:
            for item in a:
                merged.append(item)
        return merged
    
    def extract_output(self, accumulator):
        return accumulator


def run_test():
    with TestPipeline() as test_pl:
        input = test_pl | "Create" >> beam.Create(raw_input)
        (
            input
            | "Global aggregation" >> beam.CombineGlobally(Aggregated())
            | "print" >> beam.Map(print)
        )
        # The with-block runs the pipeline on exit; no explicit run() needed.

run_test()

Upvotes: 3
