inspectorG4dget

Reputation: 113905

Passing task outputs with Airflow XCom

I have an SSHOperator that writes a filepath to stdout. I'd like to get the os.path.basename of that filepath so that I can pass it as a parameter to my next task (which is an sftp pull). The idea is to download a remote file into the current working directory. This is what I have so far:

with DAG('my_dag',
         default_args = dict(...
                             xcom_push = True,
                             )
         ) as dag:

    # there is a get_update_id task here, which has been snipped for brevity

    get_results  = SSHOperator(task_id = 'get_results',
                               ssh_conn_id = 'my_remote_server',
                               command = """cd ~/path/to/dir && python results.py -t p -u {{ task_instance.xcom_pull(task_ids='get_update_id') }}""",
                               cmd_timeout = -1,
                               )

    download_results = SFTPOperator(task_id = 'download_results',
                                    ssh_conn_id = 'my_remote_server',
                                    remote_filepath = base64.b64decode("""{{ task_instance.xcom_pull(task_ids='get_results') }}"""),
                                    local_filepath = os.path.basename(base64.b64decode("""{{ task_instance.xcom_pull(task_ids='get_results') }}""").decode()),
                                    operation = 'get',
                                    )

Airflow tells me there's an error on the remote_filepath = line. Investigating this further, I see that the value passed to base64.b64decode is not the xcom value from the get_results task, but is rather the raw string starting with {{.

My feeling is that since tasks are templated, there's some under-the-hood magic to resolve the templated string, and that this magic doesn't extend to plain function calls like os.path.basename. So would I need to create an intermediate task to get the basename? Is there no shorthand like the one I've tried?

I'd appreciate any help on this.

Upvotes: 0

Views: 1706

Answers (1)

Oluwafemi Sule

Reputation: 38922

You want to decode the XCom return value when Airflow renders the remote_filepath property for the task instance.

This means that the b64decode function must be invoked within the template string. There is a catch, though: the function has to be made available in the template context, either by passing it through the operator's params or, at the DAG level, as a user-defined filter or macro.

import base64
import os

def b64decode_str(value):
    # Decode to str; rendering the raw bytes in a template would produce "b'...'"
    return base64.b64decode(value).decode()

def basename_b64decode(value):
    return os.path.basename(b64decode_str(value))

download_results = SFTPOperator(
    task_id = 'download_results',
    ssh_conn_id = 'my_remote_server',
    remote_filepath = """{{params.b64decode(ti.xcom_pull(task_ids='get_results'))}}""",
    local_filepath = """{{params.basename_b64decode(ti.xcom_pull(task_ids='get_results'))}}""",
    operation = 'get',
    params = {
       'b64decode': b64decode_str,
       'basename_b64decode': basename_b64decode,
    }
)
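
As a quick sanity check outside Airflow, the helpers can be exercised on a hand-built value. This is only a sketch: it assumes, as in the question, that the XCom value is the base64-encoded stdout of the remote command. The name b64decode_str, the sample path, and the .strip() call (which guards against a trailing newline in stdout) are illustrative assumptions, not part of the setup above.

```python
import base64
import os


def b64decode_str(value):
    """Decode a base64-encoded string to text, dropping surrounding whitespace."""
    # .strip() is an assumption: stdout often ends with a newline.
    return base64.b64decode(value).decode().strip()


def basename_b64decode(value):
    """Return only the file name portion of the decoded path."""
    return os.path.basename(b64decode_str(value))


# Hypothetical stdout of the remote command; the path is made up.
sample_stdout = "/home/user/results/update_42.csv\n"
# Stand-in for the base64 string that xcom_pull would hand to the template.
xcom_value = base64.b64encode(sample_stdout.encode()).decode()

remote_path = b64decode_str(xcom_value)
local_name = basename_b64decode(xcom_value)

# What a template would print if the helper returned bytes instead of str.
bytes_rendering = str(base64.b64decode(xcom_value))

print(remote_path)
print(local_name)
print(bytes_rendering)
```

The last value shows why the helper decodes to str: the text form of a bytes object carries a b'...' wrapper, which is not a usable path.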

For the DAG user-defined macro approach, pass the functions to the DAG itself through user_defined_macros (not through default_args), and define the task inside the with block:

with DAG('my_dag',
         default_args = dict(...
                             xcom_push = True,
                             ),
         user_defined_macros = dict(
             basename_b64decode = basename_b64decode,
             # decode to str so the rendered path is not "b'...'"
             b64decode = lambda value: base64.b64decode(value).decode(),
         ),
         ) as dag:

    download_results = SFTPOperator(
        task_id = 'download_results',
        ssh_conn_id = 'my_remote_server',
        remote_filepath = """{{b64decode(ti.xcom_pull(task_ids='get_results'))}}""",
        local_filepath = """{{basename_b64decode(ti.xcom_pull(task_ids='get_results'))}}""",
        operation = 'get',
    )
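
To illustrate why the call has to live inside the template, here is a rough sketch that contrasts calling a function while the DAG file is parsed with calling it after rendering. The toy_render function is a made-up stand-in for Airflow's Jinja rendering, the sample path is hypothetical, and base64 is left out to keep the example short.

```python
import os

# Template text as it would appear in the DAG file.
TEMPLATE = "{{ ti.xcom_pull(task_ids='get_results') }}"


def toy_render(template, xcom_value):
    """Toy stand-in for template rendering: swap the placeholder for the value."""
    return template.replace(TEMPLATE, xcom_value)


# Hypothetical value that the upstream task would produce at run time.
xcom_value = "/home/user/results/update_42.csv"

# Called at parse time: basename only ever sees the template text.
parse_time_result = os.path.basename(TEMPLATE)

# Called on the rendered string: basename sees the actual path.
render_time_result = os.path.basename(toy_render(TEMPLATE, xcom_value))

print(parse_time_result)
print(render_time_result)
```

The first result is the untouched template string, which matches what the question observed; only the second call works on the real path.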

Upvotes: 1
