SamsonStannus

Reputation: 53

Beam - Functions to run only once at the start and end of a Beam Pipeline

I have a Beam pipeline that queries BigQuery and then uploads the results to BigTable. I'd like to scale out my BigTable instance (from 1 to 10 nodes) before my pipeline starts and then scale back down (from 10 to 1 node) after the results are loaded into BigTable. Is there any mechanism to do this with Beam?

I'd essentially like to have either two separate transforms, one at the beginning of the pipeline and one at the end, that scale the nodes up and down respectively, or a DoFn that only triggers setup() and teardown() on one worker.

I've attempted to use the setup() and teardown() DoFn lifecycle functions, but these functions get executed once per worker (and I use hundreds of workers), so the pipeline attempts to scale BigTable up and down multiple times (and hits the instance and cluster write quotas for the day). So that doesn't really work for my use case. In any case, here's a snippet of the BigTableWriteFn I've been experimenting with:

import apache_beam as beam
from apache_beam.metrics import Metrics
from google.cloud.bigtable import Client


class _BigTableWriteFn(beam.DoFn):

    def __init__(self, project_id, instance_id, table_id, cluster_id, node_count):
        beam.DoFn.__init__(self)
        self.beam_options = {
            'project_id': project_id,
            'instance_id': instance_id,
            'table_id': table_id,
            'cluster_id': cluster_id,
            'node_count': node_count
        }
        self.table = None
        self.initial_node_count = None
        self.batcher = None
        self.written = Metrics.counter(self.__class__, 'Written Row')

    def setup(self):
        client = Client(project=self.beam_options['project_id'].get(), admin=True)
        instance = client.instance(self.beam_options['instance_id'].get())
        node_count = self.beam_options['node_count'].get()
        cluster = instance.cluster(self.beam_options['cluster_id'].get())
        cluster.reload()  # load current cluster state so serve_nodes is populated
        self.initial_node_count = cluster.serve_nodes
        # NOTE: this logic is flawed since cluster.serve_nodes will change after
        # the first setup() call, but I first thought setup() and teardown()
        # were run once for the whole transform...
        if node_count != self.initial_node_count:
            cluster.serve_nodes = node_count
            cluster.update()

    # ... other lifecycle methods in between, but they aren't important to the question

    def teardown(self):
        client = Client(project=self.beam_options['project_id'].get(), admin=True)
        instance = client.instance(self.beam_options['instance_id'].get())
        cluster = instance.cluster(self.beam_options['cluster_id'].get())
        cluster.reload()  # load current cluster state so serve_nodes is populated
        # NOTE: same flawed logic as in setup() -- serve_nodes will have changed
        # after the first setup() call, but I first thought setup() and
        # teardown() were run once for the whole transform...
        if cluster.serve_nodes != self.initial_node_count:
            cluster.serve_nodes = self.initial_node_count
            cluster.update()

I'm also using RuntimeValueProvider parameters for the BigTable IDs (project_id, instance_id, cluster_id, etc.), so I feel that whatever type of transform I use to scale, I'll need to use a DoFn.
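
For context, the options class behind those ValueProviders looks roughly like the sketch below (the argument names here are illustrative, not the exact ones from my pipeline):

from apache_beam.options.pipeline_options import PipelineOptions


class CustomOptions(PipelineOptions):
    @classmethod
    def _add_argparse_args(cls, parser):
        # ValueProvider-typed arguments, so values can be supplied at
        # template execution time instead of at graph construction time.
        parser.add_value_provider_argument('--bigtable_project_id', type=str)
        parser.add_value_provider_argument('--bigtable_instance_id', type=str)
        parser.add_value_provider_argument('--bigtable_table_id', type=str)
        parser.add_value_provider_argument('--bigtable_cluster_id', type=str)
        parser.add_value_provider_argument('--bigtable_node_count', type=int)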

Any help would be much appreciated!

Upvotes: 1

Views: 2249

Answers (2)

SamsonStannus

Reputation: 53

So I came up with a hacky approach, but it works.

During the setup() of my WriteFn I get the cluster's serve_nodes count (this will obviously change after the first worker calls setup()) and scale out the cluster if it's not at the desired count. In the process() function I yield this count. I then do a beam.CombineGlobally to find the Smallest(1) of those counts, and pass that to another DoFn that scales the cluster back down to that minimal count.

Here are some code snippets of what I'm doing.

import apache_beam as beam
from apache_beam.metrics import Metrics
from apache_beam.options.pipeline_options import PipelineOptions
from google.cloud.bigtable import Client


class _BigTableWriteFn(beam.DoFn):
    """ Creates the connector and adds each row in the Beam pipeline
    to the batcher.
    """
    def __init__(self, project_id, instance_id, table_id, cluster_id, node_count):
        """ Constructor of the Write connector of Bigtable
        Args:
        project_id(str): GCP project to write the rows to
        instance_id(str): GCP instance to write the rows to
        table_id(str): GCP table to write the `DirectRows` to
        cluster_id(str): GCP cluster to scale
        node_count(int): Number of nodes to scale to before writing
        """
        beam.DoFn.__init__(self)
        self.beam_options = {
            'project_id': project_id,
            'instance_id': instance_id,
            'table_id': table_id,
            'cluster_id': cluster_id,
            'node_count': node_count
        }
        self.table = None
        self.current_node_count = None
        self.batcher = None
        self.written = Metrics.counter(self.__class__, 'Written Row')

    def __getstate__(self):
        return self.beam_options

    def __setstate__(self, options):
        self.beam_options = options
        self.table = None
        self.current_node_count = None
        self.batcher = None
        self.written = Metrics.counter(self.__class__, 'Written Row')

    def setup(self):
        client = Client(project=self.beam_options['project_id'].get(), admin=True)
        instance = client.instance(self.beam_options['instance_id'].get())
        cluster = instance.cluster(self.beam_options['cluster_id'].get())
        cluster.reload()
        desired_node_count = self.beam_options['node_count'].get()
        self.current_node_count = cluster.serve_nodes
        if desired_node_count != self.current_node_count:
            cluster.serve_nodes = desired_node_count
            cluster.update()

    def start_bundle(self):
        if self.table is None:
            client = Client(project=self.beam_options['project_id'].get())
            instance = client.instance(self.beam_options['instance_id'].get())
            self.table = instance.table(self.beam_options['table_id'].get())

        self.batcher = self.table.mutations_batcher()

    def process(self, row):
        self.written.inc()
        # You need to set the timestamp in the cells of this row object;
        # when we do a retry we will be mutating the same object, and this
        # way the cells are set with fresh values.
        # Example:
        # direct_row.set_cell('cf1',
        #                     'field1',
        #                     'value1',
        #                     timestamp=datetime.datetime.now())
        self.batcher.mutate(row)
        # yield the current node count so we can find the minimum value and scale BigTable down later
        if self.current_node_count:
            yield self.current_node_count

    def finish_bundle(self):
        self.batcher.flush()
        self.batcher = None


class _BigTableScaleNodes(beam.DoFn):

    def __init__(self, project_id, instance_id, cluster_id):
        """ Constructor of the Scale connector of Bigtable
        Args:
        project_id(str): GCP project of the Bigtable instance
        instance_id(str): GCP instance that owns the cluster
        cluster_id(str): GCP cluster to scale
        """
        beam.DoFn.__init__(self)
        self.beam_options = {
            'project_id': project_id,
            'instance_id': instance_id,
            'cluster_id': cluster_id,
        }
        self.cluster = None

    def setup(self):
        if self.cluster is None:
            client = Client(project=self.beam_options['project_id'].get(), admin=True)
            instance = client.instance(self.beam_options['instance_id'].get())
            self.cluster = instance.cluster(self.beam_options['cluster_id'].get())
            # Load the current cluster state so serve_nodes is populated
            # before the comparison in process().
            self.cluster.reload()

    def process(self, min_node_counts):
        if len(min_node_counts) > 0 and self.cluster.serve_nodes != min_node_counts[0]:
            self.cluster.serve_nodes = min_node_counts[0]
            self.cluster.update()

def run():
    custom_options = PipelineOptions().view_as(CustomOptions)
    
    pipeline_options = PipelineOptions()

    p = beam.Pipeline(options=pipeline_options)
    (p
    | 'Query BigQuery' >> beam.io.Read(beam.io.BigQuerySource(query=QUERY, use_standard_sql=True))
    | 'Map Query Results to BigTable Rows' >> beam.Map(to_direct_rows)
    | 'Write BigTable Rows' >> beam.ParDo(_BigTableWriteFn(
        custom_options.bigtable_project_id, 
        custom_options.bigtable_instance_id, 
        custom_options.bigtable_table_id,
        custom_options.bigtable_cluster_id,
        custom_options.bigtable_node_count))
    | 'Find Global Min Node Count' >> beam.CombineGlobally(beam.combiners.Smallest(1))
    | 'Scale Down BigTable' >> beam.ParDo(_BigTableScaleNodes(
        custom_options.bigtable_project_id, 
        custom_options.bigtable_instance_id, 
        custom_options.bigtable_cluster_id))
    )

    result = p.run()
    result.wait_until_finish()

Upvotes: 3

bigbounty

Reputation: 17368

If you are running the Dataflow job not as a template but as a jar in a VM or pod, then you can do this before and after the pipeline runs by executing bash commands from Java. Refer to this answer - https://stackoverflow.com/a/26830876/6849682

Command to execute -

gcloud bigtable clusters update CLUSTER_ID --instance=INSTANCE_ID --num-nodes=NUM_NODES

But if you are running it as a template, the template file won't capture anything other than what's between pipeline start and end.
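
For a Python pipeline like the one in the question, the same idea would look roughly like the sketch below (an illustration only; resize_cluster and the hard-coded instance/cluster names and node counts are placeholders, not taken from the question):

import subprocess

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions


def resize_cluster(instance_id, cluster_id, num_nodes):
    # Shell out to the gcloud command above to resize the Bigtable cluster.
    subprocess.check_call([
        'gcloud', 'bigtable', 'clusters', 'update', cluster_id,
        '--instance=%s' % instance_id,
        '--num-nodes=%d' % num_nodes,
    ])


def run():
    resize_cluster('my-instance', 'my-cluster', 10)  # scale up before the pipeline
    try:
        p = beam.Pipeline(options=PipelineOptions())
        # ... build the BigQuery -> BigTable pipeline as in the question ...
        p.run().wait_until_finish()
    finally:
        resize_cluster('my-instance', 'my-cluster', 1)  # scale back down afterwards

Note that this only works when you control the process that launches the pipeline, which is exactly the non-template case described above.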

Upvotes: 2
