Lena Meer

Reputation: 89

How to pass a query parameter to a SQL file using the PythonOperator in Airflow

I have a Python function that runs a query in BigQuery, builds a list from the results, and then pushes that list as an XCom.
The SQL query lives in a file, and I want to pass parameters into it.
I want something similar to what {{ params.name }} allows for the BigQueryOperator, but with a PythonOperator.

The DAG runs successfully, but it doesn't insert the DAG_NAME variable into the SQL.
It just treats '{DAG_NAME}' as a static string.

Please help, I've been stuck on this for a while.

My code:

from datetime import datetime, timedelta
from pathlib import Path

import gcsfs
from google.cloud import bigquery

from airflow.models import DAG, Variable
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
from airflow.contrib.hooks.bigquery_hook import BigQueryHook

default_args = {
    'start_date': datetime(2020, 1, 1),
}

PROJECT_ID_GCP = Variable.get('project_GCP')
PROJECT_ID = Variable.get('project_BigQuery')
DATASET_MRR = Variable.get('LP_MRR')
DATASET_STG = Variable.get('LP_STG')
DATASET_DWH = Variable.get('LP_DWH')
DATASET_MNG = Variable.get('LP_MNG')
PARTITION_STATUS_TABLE = Variable.get('PARTITION_STATUS_TABLE')
MAPPING_TABLE = Variable.get('MAPPING_TABLE')


DAG_NAME = "my_dag_name"

gs = 'gs:/'
dir_name = '/my-bucket/dags/airflowsql/Global_Scripts'
filename = 'find_partitions'
suffix = '.sql'


def list_dates_in_df(ti, **kwargs):
    hook = BigQueryHook(bigquery_conn_id=PROJECT_ID,
                        use_legacy_sql=False)
    bq_client = bigquery.Client(project=hook._get_field("project"),
                                credentials=hook._get_credentials())

    # Read the SQL file from GCS.
    fs = gcsfs.GCSFileSystem(project=PROJECT_ID_GCP)
    with fs.open(Path(gs, dir_name, filename).with_suffix(suffix), 'r') as f:
        data = f.read()

    # Run the query and flatten the result into a list of partition keys.
    df = bq_client.query(data).to_dataframe()
    res = df.values.tolist()
    bf_yesterday = (datetime.now() - timedelta(days=2)).strftime('%Y%m%d')
    default_list = [bf_yesterday]
    my_list_a = [item for l in res for item in l]
    my_list = [i.replace('-', '') for i in my_list_a]
    # Push the list (or a default of two days ago) as an XCom.
    if my_list:
        ti.xcom_push(key='my_list', value=my_list)
    else:
        ti.xcom_push(key='my_list', value=default_list)

with DAG(
        'test_path',
        schedule_interval=None,
        catchup=False,
        default_args=default_args
) as dag:

    list_dates = PythonOperator(
        task_id='list_dates',
        python_callable=list_dates_in_df,
        op_kwargs={'a': DAG_NAME},
    )

    end_job = BashOperator(
        task_id='end_job',
        bash_command='echo end_job.',
        do_xcom_push=False,
    )

    list_dates >> end_job

The SQL in the file:

select distinct(cast(PARTITION_KEY as string)) as PARTITION_KEY
from LP_MNG.PartitionStatusMonitoring
where SOURCE_TABLE in (select distinct(MRR_SOURCE_TABLE)
                       from LP_MNG.TableMapping
                       where DAG_NAME = '{DAG_NAME}')
  and IS_LOAD_COMPLETED = false;
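
(For context: f-string interpolation only happens when a string literal is evaluated in Python source, never on text read from a file, which is why the braces survive untouched. A minimal sketch of the symptom:)

DAG_NAME = "my_dag_name"

# Interpolated at definition time, because this is an f-string literal:
literal = f"where DAG_NAME = '{DAG_NAME}'"
print(literal)    # where DAG_NAME = 'my_dag_name'

# What f.read() returns is a plain str, so the braces stay literal:
from_file = "where DAG_NAME = '{DAG_NAME}'"
print(from_file)  # where DAG_NAME = '{DAG_NAME}'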

Upvotes: 0

Views: 2837

Answers (1)

Lena Meer

Reputation: 89

I've figured it out: I replace a placeholder string in the SQL file with the value passed in through op_kwargs.
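
In short, the core of the fix (a distilled sketch using the same names as the full code below):

param = "'" + kwargs['a'] + "'"               # quote the runtime value, e.g. "'my_dag_name'"
data = data.replace('DAG_NAME_PLACE', param)  # splice it over the placeholder in the SQL text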

So my code becomes:

from datetime import datetime, timedelta
from pathlib import Path

import gcsfs
from google.cloud import bigquery

from airflow.models import DAG, Variable
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
from airflow.contrib.hooks.bigquery_hook import BigQueryHook

default_args = {
    'start_date': datetime(2020, 1, 1),
}

PROJECT_ID_GCP = Variable.get('project_GCP')
PROJECT_ID = Variable.get('project_BigQuery')
DATASET_MRR = Variable.get('LP_MRR')
DATASET_STG = Variable.get('LP_STG')
DATASET_DWH = Variable.get('LP_DWH')
DATASET_MNG = Variable.get('LP_MNG')
PARTITION_STATUS_TABLE = Variable.get('PARTITION_STATUS_TABLE')
MAPPING_TABLE = Variable.get('MAPPING_TABLE')


DAG_NAME = "my_dag_name"

gs = 'gs:/'
dir_name = '/my-bucket/dags/airflowsql/Global_Scripts'
filename = 'find_partitions'
suffix = '.sql'


def list_dates_in_df(ti, **kwargs):
    hook = BigQueryHook(bigquery_conn_id=PROJECT_ID,
                        use_legacy_sql=False)
    bq_client = bigquery.Client(project=hook._get_field("project"),
                                credentials=hook._get_credentials())

    # Build the quoted value that will replace the placeholder in the SQL.
    param = "'" + kwargs['a'] + "'"

    # Read the SQL file from GCS.
    fs = gcsfs.GCSFileSystem(project=PROJECT_ID_GCP)
    with fs.open(Path(gs, dir_name, filename).with_suffix(suffix), 'r') as f:
        data = f.read()

    # Replace the placeholder in the SQL with the quoted value.
    data = data.replace('DAG_NAME_PLACE', param)

    # Run the query and flatten the result into a list of partition keys.
    df = bq_client.query(data).to_dataframe()
    res = df.values.tolist()
    bf_yesterday = (datetime.now() - timedelta(days=2)).strftime('%Y%m%d')
    default_list = [bf_yesterday]
    my_list_a = [item for l in res for item in l]
    my_list = [i.replace('-', '') for i in my_list_a]
    # Push the list (or a default of two days ago) as an XCom.
    if my_list:
        ti.xcom_push(key='my_list', value=my_list)
    else:
        ti.xcom_push(key='my_list', value=default_list)

with DAG(
        'test_path',
        schedule_interval=None,
        catchup=False,
        default_args=default_args
) as dag:

    list_dates = PythonOperator(
        task_id='list_dates',
        python_callable=list_dates_in_df,
        op_kwargs={'a': DAG_NAME},
    )

    end_job = BashOperator(
        task_id='end_job',
        bash_command='echo end_job.',
        do_xcom_push=False,
    )

    list_dates >> end_job

And my SQL is:

select distinct(cast(PARTITION_KEY as string)) as PARTITION_KEY
from LP_MNG.PartitionStatusMonitoring
where SOURCE_TABLE in (select distinct(MRR_SOURCE_TABLE)
                       from LP_MNG.TableMapping
                       where DAG_NAME = DAG_NAME_PLACE)
  and IS_LOAD_COMPLETED = false;
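
As a side note, plain string replacement works, but it splices raw text into the query; if the value could ever come from user input, BigQuery query parameters are the safer route. A minimal sketch of the same query using the google-cloud-bigquery client (the @dag_name parameter name is my own choice; bq_client and kwargs['a'] are the same objects built inside list_dates_in_df above):

from google.cloud import bigquery

sql = """
select distinct(cast(PARTITION_KEY as string)) as PARTITION_KEY
from LP_MNG.PartitionStatusMonitoring
where SOURCE_TABLE in (select distinct(MRR_SOURCE_TABLE)
                       from LP_MNG.TableMapping
                       where DAG_NAME = @dag_name)
  and IS_LOAD_COMPLETED = false
"""

# BigQuery binds @dag_name server-side, so no quoting or escaping is needed.
job_config = bigquery.QueryJobConfig(
    query_parameters=[
        bigquery.ScalarQueryParameter("dag_name", "STRING", kwargs['a']),
    ]
)
df = bq_client.query(sql, job_config=job_config).to_dataframe()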

Upvotes: 1
