Vineet

Reputation: 783

Run Multiple Athena Queries in Airflow 2.0

I am trying to create a DAG in which one of the tasks runs an Athena query using boto3. It worked for a single query, but I am facing issues when I try to run multiple Athena queries.

This problem can be broken down as follows:

  1. If one goes through this blog, it can be seen that Athena uses start_query_execution to trigger a query and get_query_execution to get the status, QueryExecutionId and other data about the query (docs for Athena); a minimal sketch of this pattern is shown below.
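
For context, a minimal sketch of that trigger-and-poll pattern with boto3 (the query, database, bucket and region below are just placeholders) would be:

import time
import boto3

client = boto3.client('athena', region_name='YOUR_REGION')

# start_query_execution() kicks the query off and returns a QueryExecutionId...
query_execution_id = client.start_query_execution(
    QueryString='SELECT 1;',
    QueryExecutionContext={'Database': 'sampledb'},
    ResultConfiguration={'OutputLocation': 's3://YOUR_BUCKET/'},
)['QueryExecutionId']

# ...which is then used to poll get_query_execution() until a terminal state is reached.
while True:
    state = client.get_query_execution(
        QueryExecutionId=query_execution_id)['QueryExecution']['Status']['State']
    if state in ('SUCCEEDED', 'FAILED', 'CANCELLED'):
        break
    time.sleep(2)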

Following the above pattern, I have the following code:

import json
import time
import asyncio
import boto3
import logging
from datetime import datetime
from airflow import DAG
from airflow.operators.python import PythonOperator


def execute_query(client, query, database, output_location):
    response = client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={
            'Database': database
        },
        ResultConfiguration={
            'OutputLocation': output_location
        }
    )

    return response['QueryExecutionId']


async def get_ids(client_athena, query, database, output_location):
    query_responses = []
    for i in range(5):
        query_responses.append(execute_query(client_athena, query, database, output_location))    

    res = await asyncio.gather(*query_responses, return_exceptions=True)

    return res

def run_athena_query(query, database, output_location, region_name, **context):
    BOTO_SESSION = boto3.Session(
        aws_access_key_id = 'YOUR_KEY',
        aws_secret_access_key = 'YOUR_ACCESS_KEY')
    client_athena = BOTO_SESSION.client('athena', region_name=region_name)

    loop = asyncio.get_event_loop()
    query_execution_ids = loop.run_until_complete(get_ids(client_athena, query, database, output_location))
    loop.close()

    repetitions = 900
    error_messages = []
    s3_uris = []

    while repetitions > 0 and len(query_execution_ids) > 0:
        repetitions = repetitions - 1
        
        query_response_list = client_athena.batch_get_query_execution(
            QueryExecutionIds=query_execution_ids)['QueryExecutions']
      
        for query_response in query_response_list:
            if 'QueryExecution' in query_response and \
                    'Status' in query_response['QueryExecution'] and \
                    'State' in query_response['QueryExecution']['Status']:
                state = query_response['QueryExecution']['Status']['State']

                if state in ['FAILED', 'CANCELLED']:
                    error_reason = query_response['QueryExecution']['Status']['StateChangeReason']
                    error_message = 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'.format(
                            state, query_response['QueryExecutionId'], error_reason
                        )
                    error_messages.append(error_message)
                    query_execution_ids.remove(query_response['QueryExecutionId'])
                
                elif state == 'SUCCEEDED':
                    result_location = query_response['QueryExecution']['ResultConfiguration']['OutputLocation']
                    s3_uris.append(result_location)
                    query_execution_ids.remove(query_response['QueryExecutionId'])
                 
                    
        time.sleep(2)
    
    logging.exception(error_messages)
    return s3_uris


DEFAULT_ARGS = {
    'owner': 'ubuntu',
    'depends_on_past': True,
    'start_date': datetime(2021, 6, 8),
    'retries': 0,
    'concurrency': 2
}

with DAG('resync_job_dag', default_args=DEFAULT_ARGS, schedule_interval=None) as dag:

    ATHENA_QUERY = PythonOperator(
        task_id='athena_query',
        python_callable=run_athena_query,
        provide_context=True,
        op_kwargs={
            'query': 'SELECT request_timestamp FROM "sampledb"."elb_logs" limit 10;', # query provided in the Athena tutorial
            'database':'sampledb',
            'output_location':'YOUR_BUCKET',
            'region_name':'YOUR_REGION'
        }
    )

    ATHENA_QUERY

On running the above code, I get the following error:

[2021-06-16 20:34:52,981] {taskinstance.py:1455} ERROR - An asyncio.Future, a coroutine or an awaitable is required
Traceback (most recent call last):
  File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/models/taskinstance.py", line 1112, in _run_raw_task
    self._prepare_and_execute_task_with_callbacks(context, task)
  File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/models/taskinstance.py", line 1285, in _prepare_and_execute_task_with_callbacks
    result = self._execute_task(context, task_copy)
  File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/models/taskinstance.py", line 1315, in _execute_task
    result = task_copy.execute(context=context)
  File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/operators/python.py", line 117, in execute
    return_value = self.execute_callable()
  File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/operators/python.py", line 128, in execute_callable
    return self.python_callable(*self.op_args, **self.op_kwargs)
  File "/home/ubuntu/iac-airflow/dags/helper/tasks.py", line 93, in run_athena_query
    query_execution_ids = loop.run_until_complete(get_ids(client_athena, query, database, output_location))
  File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
    return future.result()
  File "/home/ubuntu/iac-airflow/dags/helper/tasks.py", line 79, in get_ids
    res = await asyncio.gather(*query_responses, return_exceptions=True)
  File "/usr/lib/python3.6/asyncio/tasks.py", line 602, in gather
    fut = ensure_future(arg, loop=loop)
  File "/usr/lib/python3.6/asyncio/tasks.py", line 526, in ensure_future
    raise TypeError('An asyncio.Future, a coroutine or an awaitable is '
TypeError: An asyncio.Future, a coroutine or an awaitable is required

I am unable to figure out where I am going wrong and would appreciate a hint on this issue.

Upvotes: 3

Views: 3707

Answers (2)

Vineet

Reputation: 783

The following also worked for me. I had just overcomplicated a simple problem with asyncio.
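
For reference, the TypeError in the question came from passing asyncio.gather a list of plain QueryExecutionId strings instead of awaitables. A rough sketch of how the coroutine version could have been made to work (reusing the same synchronous execute_query helper and offloading it to the default thread pool) would be something like:

import asyncio

async def get_ids(client_athena, query, database, output_location):
    loop = asyncio.get_event_loop()
    # run_in_executor() wraps each synchronous boto3 call in an awaitable
    # future, which is what asyncio.gather expects.
    futures = [
        loop.run_in_executor(None, execute_query, client_athena, query,
                             database, output_location)
        for _ in range(5)
    ]
    return await asyncio.gather(*futures, return_exceptions=True)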

Since I ultimately needed the S3 URI of each query's result, I went for writing the script from scratch. With the current implementation of AWSAthenaOperator, one can get the QueryExecutionId and then do the remaining processing (i.e. create another task) to fetch the S3 URI of the CSV result file. This adds some overhead in terms of the delay between the two tasks (one getting the QueryExecutionId, the other retrieving the S3 URI), along with added resource usage.
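
For illustration, that two-task approach might look roughly like this (the DAG id, task ids, region, bucket and the fetch_result_uri helper are made up for the sketch; AWSAthenaOperator pushes its QueryExecutionId to XCom as its return value):

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator
from airflow.utils.dates import days_ago
import boto3


def fetch_result_uri(**context):
    # Pull the QueryExecutionId pushed to XCom by the upstream Athena task,
    # then resolve the S3 URI of its CSV result file.
    query_execution_id = context['ti'].xcom_pull(task_ids='athena_query')
    client = boto3.client('athena', region_name='YOUR_REGION')
    response = client.get_query_execution(QueryExecutionId=query_execution_id)
    return response['QueryExecution']['ResultConfiguration']['OutputLocation']


with DAG('athena_two_task_sketch', start_date=days_ago(1), schedule_interval=None) as dag:
    athena_query = AWSAthenaOperator(
        task_id='athena_query',
        query='SELECT 1;',
        database='sampledb',
        output_location='s3://YOUR_BUCKET/',
    )

    get_result_uri = PythonOperator(
        task_id='get_result_uri',
        python_callable=fetch_result_uri,
    )

    athena_query >> get_result_uri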

Therefore, I went for doing the complete operation in a single operator, as follows:

Code:

import time
import boto3
import logging
from datetime import datetime
from airflow import DAG
from airflow.operators.python import PythonOperator


def execute_query(client, query, database, output_location):
    response = client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={
            'Database': database
        },
        ResultConfiguration={
            'OutputLocation': output_location
        }
    )

    return response


def run_athena_query(query, database, output_location, region_name, **context):
    BOTO_SESSION = boto3.Session(
        aws_access_key_id = 'YOUR_KEY',
        aws_secret_access_key = 'YOUR_ACCESS_KEY')
    client_athena = BOTO_SESSION.client('athena', region_name=region_name)

    query_execution_ids = []
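    # message_list is assumed to be provided by an upstream task (e.g. pulled
    # from XCom from a Kafka consumer); its definition is not shown in this snippet.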
    if message_list:
        for parameter in message_list:
            query_response = execute_query(client_athena, query, database, output_location)
            query_execution_ids.append(query_response['QueryExecutionId'])
    else:
        raise Exception(
            'Error in upstream value received from kafka consumer. Got message list as - {}, with type {}'
                .format(message_list, type(message_list))
        )


    repetitions = 900
    error_messages = []
    s3_uris = []

    while repetitions > 0 and len(query_execution_ids) > 0:
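        # Poll batch_get_query_execution until each query reaches a terminal
        # state, up to 900 iterations at a 2-second interval (about 30 minutes).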
        repetitions = repetitions - 1
        
        query_response_list = client_athena.batch_get_query_execution(
            QueryExecutionIds=query_execution_ids)['QueryExecutions']
      
        for query_response in query_response_list:
            if 'Status' in query_response and 'State' in query_response['Status']:
                state = query_response['Status']['State']

                if state in ['FAILED', 'CANCELLED']:
                    error_reason = query_response['Status'].get('StateChangeReason', '')
                    error_message = 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'.format(
                            state, query_response['QueryExecutionId'], error_reason
                        )
                    error_messages.append(error_message)
                    query_execution_ids.remove(query_response['QueryExecutionId'])

                elif state == 'SUCCEEDED':
                    result_location = query_response['ResultConfiguration']['OutputLocation']
                    s3_uris.append(result_location)
                    query_execution_ids.remove(query_response['QueryExecutionId'])

        time.sleep(2)
    
    if error_messages:
        logging.error(error_messages)
    return s3_uris


DEFAULT_ARGS = {
    'owner': 'ubuntu',
    'depends_on_past': True,
    'start_date': datetime(2021, 6, 8),
    'retries': 0,
    'concurrency': 2
}

with DAG('resync_job_dag', default_args=DEFAULT_ARGS, schedule_interval=None) as dag:

    ATHENA_QUERY = PythonOperator(
        task_id='athena_query',
        python_callable=run_athena_query,
        provide_context=True,
        op_kwargs={
            'query': 'SELECT request_timestamp FROM "sampledb"."elb_logs" limit 10;', # query provided in the Athena tutorial
            'database':'sampledb',
            'output_location':'YOUR_BUCKET',
            'region_name':'YOUR_REGION'
        }
    )

    ATHENA_QUERY

However, the approach shared by @Elad is cleaner and more apt if one only wants to get the QueryExecutionIds of all the queries.

Upvotes: 0

Elad Kalif

Reputation: 15931

I think what you are doing here isn't really needed. Your issues are:

  1. Executing multiple queries in parallel.
  2. Being able to recover queryExecutionId per query.

Both issues are solved simply by using AWSAthenaOperator. The operator already handles everything you mentioned for you.

Example:

from airflow.models import DAG
from airflow.utils.dates import days_ago
from airflow.operators.dummy import DummyOperator
from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator


with DAG(
    dag_id="athena",
    schedule_interval='@daily',
    start_date=days_ago(1),
    catchup=False,
) as dag:

    start_op = DummyOperator(task_id="start_task")
    query_list = ["SELECT 1;", "SELECT 2;", "SELECT 3;"]

    for i, sql in enumerate(query_list):
        run_query = AWSAthenaOperator(
            task_id=f'run_query_{i}',
            query=sql,
            output_location='s3://my-bucket/my-path/',
            database='my_database'
        )
        start_op >> run_query

Athena tasks will be created dynamically simply by adding more queries to query_list.


Note that the QueryExecutionId is pushed to XCom, so you can access it in a downstream task if needed.
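
For example, something along these lines could collect all of the ids in a downstream task (the collect_execution_ids task and its callable are just illustrative, assuming the three queries in query_list above):

from airflow.operators.python import PythonOperator


def collect_execution_ids(**context):
    # Each AWSAthenaOperator pushes its QueryExecutionId to XCom as its return
    # value, so they can all be pulled by task id.
    return context['ti'].xcom_pull(
        task_ids=[f'run_query_{i}' for i in range(3)])  # one entry per query in query_list


# Inside the same DAG context, after the loop above:
collect_ids = PythonOperator(
    task_id='collect_execution_ids',
    python_callable=collect_execution_ids,
)

Each run_query task would then be set upstream of collect_ids.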

Upvotes: 2
