soumeng78
soumeng78

Reputation: 880

How do I pass both the context (for XCom) and a parameter to python_callable?

I want to have a PythonOperator task which will accept an input parameter data_path and randomly select a csv file from the path and pass the randomly selected file to the subsequent task in the DAG.

For the later part (passing the selected file to subsequent task in the DAG), I want to use xcom_push. How do I need to write my callable and operator?

Without the XCom part, my rough code looks like the following:

def _select_random_data(**kwargs):
    data_dir = kwargs.get('data_path')
    logger   = kwargs.get('logger')

    if not os.path.exists(data_dir):
        raise RuntimeError('data directory does not exist')

    pattern = f'{data_dir}/*.csv'
    random_filename = random.choice([x for x in glob.glob(pattern) if os.path.isfile(x)])

    logger.info('random file name: {}'.format(random_filename))

    #task_instance = kwargs['task_instance']
    #task_instance.xcom_push(key='file_name', value=random_filename)


 # Task that picks a random CSV; provide_context=True makes Airflow 1.x pass
 # the template context (ti, ds, ...) into the callable's **kwargs.
 # NOTE(review): provide_context is deprecated/ignored on Airflow 2.x (the
 # context is always passed) -- confirm the target Airflow version.
 # NOTE(review): passing a live logger object through op_kwargs may not
 # survive DAG serialization on Airflow 2.x -- verify, or build the logger
 # inside the callable instead.
 select_file_task = PythonOperator(task_id='select_file',
                                       python_callable=_select_random_data,
                                       provide_context=True,
                                       op_kwargs={
                                           'data_path': 'path1',
                                           'logger': logger,
                                       })

Upvotes: 0

Views: 1313

Answers (1)

soumeng78
soumeng78

Reputation: 880

I could make the code below work with "airflow tasks test <dag_id> <task_id> <execution_date>", but the same workflow pipeline does not work when I run it from the web UI — in the UI, my DAG fails on select_file_task, and I don't get any details in the logs.

def _select_file(data_dir, **kwargs):
    import logging
    logger = logging.getLogger("airflow.task")

    #LoggingMixin().log.info('input data path: {}'.format(data_dir))
    logger.info('input data path: {}'.format(data_dir))

    if not os.path.exists(data_dir):
        raise RuntimeError('data directory does not exist')

    pattern = f'{data_dir}/*.csv'
    random_filename = random.choice([x for x in glob.glob(pattern) if os.path.isfile(x)])

    logger.info('random file name: {}'.format(random_filename))

    ti = kwargs['ti']
    ti.xcom_push(key='file_name', value=random_filename)


def _print_filename(**kwargs):
    import logging
    logger = logging.getLogger("airflow.task")

    ti = kwargs['ti']
    file_name = ti.xcom_pull(task_ids='select_file', key='file_name')

    logger.info('selected file name: {}'.format(file_name))

# DAG wiring: start -> select_file -> print_file -> stop.
with DAG('test_pipeline',
         default_args=default_args,
         description='A pipeline for experiments',
         schedule_interval='*/5 * * * *',
         start_date=days_ago(1),
         # BUG FIX: was timedelta(seconds=5), which aborted every scheduled
         # run after five seconds and killed its tasks with no task logs --
         # the likely cause of the UI-only failure, since
         # `airflow tasks test` does not enforce dagrun_timeout.
         dagrun_timeout=timedelta(minutes=5),
         # With a */5 schedule and start_date a day in the past, default
         # catchup would back-fill ~288 runs on first enable.
         catchup=False,
         ) as dag:

    start_task = DummyOperator(task_id="start")
    stop_task = DummyOperator(task_id="stop")

    # provide_context=True is required on Airflow 1.x so `ti` reaches
    # **kwargs; Airflow 2.x always supplies the context.
    select_file_task = PythonOperator(task_id='select_file',
                                      python_callable=_select_file,
                                      provide_context=True,
                                      op_args=['/vol/algo1/validation_data'])

    print_file_task = PythonOperator(task_id='print_file',
                                     python_callable=_print_filename,
                                     provide_context=True)

    start_task >> select_file_task >> print_file_task >> stop_task

Upvotes: 0

Related Questions