enmyj

Reputation: 401

Airflow - how to set task dependencies between iterations of a for loop?

I am using Airflow to run a set of tasks inside a for loop. The purpose of the loop is to iterate through a list of database table names and perform the following actions:

for table_name in list_of_tables:
    if table exists in database (BranchPythonOperator):
        do nothing (DummyOperator)
    else:
        create table (JdbcOperator)
    insert records into table (JdbcOperator, Trigger on One Success)

In the Web UI, this appears as:

[screenshot: the tasks generated by the for loop]

Currently, Airflow executes the tasks in this image from top to bottom then left to right, like: tbl_exists_fake_table_one --> tbl_exists_fake_table_two --> tbl_create_fake_table_one, etc.

However, the insert statement for fake_table_two depends on fake_table_one being updated first, a dependency that Airflow does not currently capture. (Technically the dependency is implied by the ordering of list_of_table_names, but I believe this will be error-prone in a more complex situation.)

I want all tasks related to fake_table_one to run, followed by all tasks related to fake_table_two. How can I accomplish this in Airflow?

Full code below:

# Operator/hook imports (Airflow 1.x paths); dag, script_path, sql_parse, and
# list_of_table_names are defined elsewhere
from airflow.hooks.jdbc_hook import JdbcHook
from airflow.operators.jdbc_operator import JdbcOperator
from airflow.operators.python_operator import BranchPythonOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.trigger_rule import TriggerRule

for tbl_name in list_of_table_names:

    # Check if table exists by querying information tables
    def has_table(tbl_name=tbl_name):
        p = JdbcHook('conn_id')
        sql = """select count(*) from system.tables where name = '{}'""".format(tbl_name.upper())
        count = p.get_records(sql)[0][0]  # unpack the list/tuple

        # If the table does not exist (count == 0), branch to the Create Table task;
        # otherwise branch to the Dummy Operator (Airflow requires that both branches have a task)
        if count == 0:
            return 'tbl_create_{}'.format(tbl_name)
        else:
            return 'dummy_{}'.format(tbl_name)

    # run has_table python function
    exists = BranchPythonOperator(
        task_id='tbl_exists_{}'.format(tbl_name),
        python_callable=has_table,
        depends_on_past=False,
        dag=dag
    )

    # Dummy Operator
    dummy = DummyOperator(task_id='dummy_{}'.format(tbl_name), depends_on_past=False, dag=dag)

    # Run create table SQL script
    create = JdbcOperator(
        task_id='tbl_create_{}'.format(tbl_name),
        jdbc_conn_id='conn_id',
        sql=sql_parse(script_path, 'sql/sql_create/{}.sql'.format(tbl_name)),
        depends_on_past=False,
        dag=dag
    )

    # Run insert or truncate/replace SQL script
    upsert = JdbcOperator(
        task_id='tbl_upsert_{}'.format(tbl_name),
        jdbc_conn_id='conn_id',
        sql=sql_parse(script_path, 'sql/sql_upsert/{}.sql'.format(tbl_name)),
        trigger_rule=TriggerRule.ONE_SUCCESS,
        dag=dag
    )

    # Set dependencies
    exists >> create >> upsert 
    exists >> dummy >> upsert

Upvotes: 3

Views: 9422

Answers (1)

jhnclvr

Reputation: 9487

Store a reference to the last task added at the end of each loop iteration. Then, at the beginning of the next iteration, check whether that reference exists; if it does, set it upstream of the current iteration's first task.

Something like this:

last_task = None

for tbl_name in list_of_table_names:


    # run has_table python function
    exists = BranchPythonOperator(
        task_id='tbl_exists_{}'.format(tbl_name),
        python_callable=has_table,
        depends_on_past=False,
        dag=dag
    )

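    # Chain across iterations: this table's first task waits for the previous table's last task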
    if last_task:
        last_task >> exists


    ...


    # Run insert or truncate/replace SQL script
    upsert = JdbcOperator(
        task_id='tbl_upsert_{}'.format(tbl_name),
        jdbc_conn_id='conn_id',
        sql=sql_parse(script_path, 'sql/sql_upsert/{}.sql'.format(tbl_name)),
        trigger_rule=TriggerRule.ONE_SUCCESS,
        dag=dag
    )

    last_task = upsert

    ...
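
If it helps, here is a minimal, self-contained sketch of the same idea. The DAG arguments and table names are placeholders, and DummyOperators stand in for the real exists/create/upsert tasks (Airflow 1.x import paths):

from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

# Placeholder DAG; swap in your own dag and operators
dag = DAG('chain_table_loops', start_date=datetime(2019, 1, 1), schedule_interval=None)

last_task = None
for tbl_name in ['fake_table_one', 'fake_table_two']:
    # DummyOperators stand in for the real exists/create/upsert tasks
    exists = DummyOperator(task_id='tbl_exists_{}'.format(tbl_name), dag=dag)
    upsert = DummyOperator(task_id='tbl_upsert_{}'.format(tbl_name), dag=dag)
    exists >> upsert

    # Everything for the previous table finishes before this table starts
    if last_task:
        last_task >> exists

    last_task = upsert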

Upvotes: 7
