Vladimir Medvedev

Reputation: 11

How to correctly execute multiple Airflow operators sequentially in a for loop inside of a task based on input parameter

I have an Airflow DAG that must run an external job for a specific period provided by the user in params. The DAG has the following flow:

  1. List files in the S3 bucket for the date range, verify that specific files are present, and return a map {partition -> {start_date, end_date}}.
  2. Use the map returned in the first step to dynamically create multiple "run_job" tasks and run them in parallel by partition.
  3. Inside each "run_job", iterate over the list of periods {start_date, end_date} and, for each, start a new DatabricksRunNowOperator, executing them sequentially one by one.
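
For illustration, the mapping returned by step 1 might look like the following (partition names and dates here are made up, not from the real bucket):

```python
# Hypothetical example of the structure returned by prepare_period():
# each partition maps to a list of {start_date, end_date} periods
# that should be processed sequentially for that partition.
partition_to_periods = {
    "partition_a": [
        {"start_date": "2023-06-01", "end_date": "2023-06-07"},
        {"start_date": "2023-06-08", "end_date": "2023-06-14"},
    ],
    "partition_b": [
        {"start_date": "2023-06-01", "end_date": "2023-06-14"},
    ],
}

# With .expand(), each (partition, periods) pair becomes one mapped
# run_job task instance; the mapped instances run in parallel by partition.
mapped_inputs = list(partition_to_periods.items())
```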

Below is a shortened version of the DAG:

with DAG(
        dag_id="backfill",
        start_date=datetime(2023, 6, 1),
        catchup=False,
        render_template_as_native_obj=True,
        max_active_tasks=10,
        schedule=None,
        params={
            "start_date": Param(DEFAULT_DATE_PARAM, format="date", type="string"),
            "end_date": Param(DEFAULT_DATE_PARAM, format="date", type="string"),
        },
) as backfill:
    @task
    def run_job(period_to_partition: tuple[str, list[dict[str, str]]], **context):
        tasks = []
        partition, periods = period_to_partition

        for period in periods:
            start_date = period.get("start_date")
            end_date = period.get("end_date")

            job_params = prepare_job_args(start_date=start_date, end_date=end_date)

            run_job = DatabricksRunNowOperator(
                job_id="job_id",
                task_id=f"run_dbx_job_{partition}_{start_date}_{end_date}",
                spark_submit_params=job_params,
                dag=backfill,
                trigger_rule=TriggerRule.ALWAYS,
                retries=3,
            )
            tasks.append(run_job)

        chain(*tasks)

        for single_task in tasks:
            single_task.execute(context)


    @task
    def prepare_period(**context):
        """List files in S3 bucket for the date range, verify that specific files are present 
        and return a map {partition -> {start_date, end_date}}."""
        params = context["params"]
        return split_to_partition_to_periods(params["start_date"], params["end_date"], context)


    run_job.expand(period_to_partition=prepare_period())

The code above works fine under normal circumstances, but if one of the DatabricksRunNowOperator instances fails, the remaining operators in the tasks list are not executed and the DAG fails. It seems the approach I have chosen is wrong, but I don't see any other way.

QUESTION

Is there any way to run DatabricksRunNowOperator inside a task so that, if one operator fails, the others can still be executed?

THINGS I HAVE TRIED

  1. I have tried removing the chaining of the operators in the loop and adding trigger_rule=ALL_DONE to each operator, but it had no effect. It seems the trigger_rule, retries and other operator parameters are overridden by the wrapper task's parameters and are not taken into account.
  2. Getting the data from S3 outside of an Airflow task is not possible, because querying S3 requires the user-provided input parameters, i.e. start_date and end_date.
  3. I cannot move the creation of the operator to another @task, because I would not be able to call that task from the run_job task.
  4. I cannot use @task_group or move the operator-creation logic to the DAG level, because I need the period information from an Airflow task in order to know how many operators to create sequentially.
  5. I have tried moving trigger_rule and retries to the DAG level as defaults for every task, but they seem to affect only the wrapper task and not the operators.

Upvotes: 1

Views: 747

Answers (1)

Vladimir Medvedev

Reputation: 11

As it turns out, a DatabricksRunNowOperator executed inside a @task is treated as a standalone task with no upstream or downstream tasks, and its failure is treated as a failure of the wrapper task as well.

To work around this issue, multiple approaches were tried:

Creating a dedicated pool with a single slot for each partition, so that all tasks for a given partition run one by one and a ConcurrentModificationException is avoided. This is not yet possible, because Airflow does not allow pools to be defined dynamically: https://github.com/apache/airflow/issues/33657

Catching and muting the Airflow exception manually inside the for loop that chained execution of tasks for the different partitions. This did not work: if the exception is caught and muted, Airflow's retry and slack_alert mechanisms never fire, because they receive no signal that the task failed.
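
Reduced to plain Python, the problem with this second approach can be sketched as follows (a generic illustration, not Airflow-specific code): once the exception is swallowed inside the loop, the wrapper function returns normally, so the caller has no way to tell that anything failed.

```python
def run_operator(fail: bool) -> None:
    """Stand-in for operator.execute(); raises on failure."""
    if fail:
        raise RuntimeError("job failed")

def wrapper_task(flags):
    """Runs each 'operator' in sequence, muting failures."""
    failures = []
    for i, fail in enumerate(flags):
        try:
            run_operator(fail)
        except RuntimeError as exc:
            # Muting the error here means the function exits normally,
            # so the caller (Airflow) never sees a failure:
            # no retries are scheduled and no alerts fire.
            failures.append((i, str(exc)))
    return failures  # the wrapper "succeeds" even though a job failed

result = wrapper_task([False, True, False])
```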

SOLUTION

Finally, a mixed approach was chosen. The logic creating the DatabricksRunNowOperator was moved to a separate DAG (run_dbx_job). As a result, Airflow handles retries of the DatabricksRunNowOperator as a regular task in that DAG, while the caller DAG (backfill) catches the exception and proceeds with the remaining periods if run_dbx_job fails even after its retries.

CODE EXAMPLE

with DAG(
    dag_id="backfill",
    start_date=datetime(2023, 6, 1),
    catchup=False,
    render_template_as_native_obj=True,
    max_active_tasks=10,
    schedule=None,
    params={
        "start_date": Param(DEFAULT_DATE_PARAM, format="date", type="string"),
        "end_date": Param(DEFAULT_DATE_PARAM, format="date", type="string"),
    },
) as backfill:
    @task
    def run_job(period_to_partition: tuple[str, list[dict[str, str]]], **context):
        partition, periods = period_to_partition

        for period in periods:
            start_date = period.get("start_date")
            end_date = period.get("end_date")

            trigger = TriggerDagRunOperator(
                task_id=f"run_dbx_job_{partition}_{start_date}_{end_date}",
                trigger_dag_id="run_dbx_job",
                wait_for_completion=True,
                poke_interval=30,
                conf={"partition": partition, "start_date": start_date, "end_date": end_date},
            )
            try:
                trigger.execute(context=context)
            except AirflowException as exception:
                task_logger.error("TriggerDagRunOperator (id=%s) failed: %s", trigger.task_id, exception)


    @task
    def prepare_period(**context):
        """List files in S3 bucket for the date range, verify that specific files are present
        and return a map {partition -> {start_date, end_date}}."""
        params = context["params"]
        return split_to_partition_to_periods(params["start_date"], params["end_date"], context)


    run_job.expand(period_to_partition=prepare_period())

Upvotes: 0
