AlphaBetaGamma

Reputation: 1940

Pass dynamic args to DataprocSubmitJobOperator from xcom

I am trying to receive an event from Pub/Sub and, based on the message, pass some arguments to my Dataproc Spark job.

Now, job_args in the code below is a dictionary. I have managed to push job_args as a dictionary to XCom from the python callable create_args_from_event, BUT the problem is that when I use xcom_pull in my DAG it comes back as a string, while DataprocSubmitJobOperator expects a dictionary object. Is there any way I can use it as it was created and pushed to XCom? The goal is to pull the dictionary object from XCom and pass it to DataprocSubmitJobOperator.

I have already tried render_template_as_native_obj=True, and the difference is that when I print the value in another Python callable its class comes out as dict, but I am not sure how to use that here.

dag = DAG(dag_id=dag_id, schedule_interval=None, default_args=default_args,render_template_as_native_obj=True)
with dag:        
    t1 = PubSubPullSensor(task_id='pull-messages',
                              project="projectname",
                              ack_messages=True,
                              max_messages=1,
                              subscription="subscribtionname")

    message = "{{ task_instance.xcom_pull() }}"
    t2 = PythonOperator(
            task_id='define_args',
            python_callable=create_args_from_event,
            op_kwargs={'var': message},
            provide_context=True,
    )
    job_args = "{{ task_instance.xcom_pull(task_ids='define_args', 
     key='define_args') }}"
    
    submit_job = {
        "reference": {"project_id": v_project_id},
        "placement": {"cluster_name": v_cluster_name},
        "spark_job": job_args["gcs_job"]
    }
    
    
    spark_job_submit = DataprocSubmitJobOperator(
            task_id="XXXX",
            job=submit_job,
            location="us-central1",
            gcp_conn_id=v_conn_id,
            project_id=v_project_id
        )


The dictionary I expect (already created and pushed) looks like this:

    job_args = {
            "gcs_job": {
                "args": ["--foo=bar", "--foo2=bar2"],
                "jar_file_uris": ["gs://...."],
                "main_class": "com.xyz.something"
            }
        }

Upvotes: 1

Views: 1410

Answers (2)

Mazlum Tosun

Reputation: 6572

If I have understood your need correctly, you want to pass the job args to the DataprocSubmitJobOperator via XCom.

In this case, you can create a custom operator that extends DataprocSubmitJobOperator, for example:

import ast
from typing import Dict, Optional, Union, Sequence, Tuple

from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator
from google.api_core.gapic_v1.method import _MethodDefault, DEFAULT
from google.api_core.retry import Retry


class CustomDataprocSubmitJobOperator(DataprocSubmitJobOperator):

    def __init__(
            self,
            job: Dict,
            region: str,
            project_id: Optional[str] = None,
            request_id: Optional[str] = None,
            retry: Union[Retry, _MethodDefault] = DEFAULT,
            timeout: Optional[float] = None,
            metadata: Sequence[Tuple[str, str]] = (),
            gcp_conn_id: str = "google_cloud_default",
            impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
            asynchronous: bool = False,
            cancel_on_kill: bool = True,
            wait_timeout: Optional[int] = None) -> None:
        super().__init__(
            job=job,
            region=region,
            project_id=project_id,
            request_id=request_id,
            retry=retry,
            timeout=timeout,
            metadata=metadata,
            gcp_conn_id=gcp_conn_id,
            impersonation_chain=impersonation_chain,
            asynchronous=asynchronous,
            cancel_on_kill=cancel_on_kill,
            wait_timeout=wait_timeout)

    def execute(self, context):
        task_instance = context['task_instance']

        # Retrieve the job args from XCom.
        job_args = task_instance.xcom_pull(task_ids='define_args', key='define_args')

        # Apply a transformation on the job args if needed. Inside execute() the
        # XCom value is already deserialized, so a dict usually comes back as a
        # dict; if it arrives as a string, parse it back into a dict.
        if isinstance(job_args, str):
            job_args = ast.literal_eval(job_args)
        expected_job_args = job_args

        # Set the transformed job args on the field the operator expects.
        self.job = expected_job_args

        return super().execute(context)

Some explanations:

  • I created a class called CustomDataprocSubmitJobOperator that extends DataprocSubmitJobOperator
  • In the execute method of the operator, I have access to the current context. Via this context, I can retrieve the job args with xcom_pull
  • If needed, I can apply a transformation to the job args (for example, turning a String into a Dict)
  • At the end, I set the transformed job args on the expected field of the operator, job in my example; see the usage sketch below
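
For completeness, a minimal sketch of how the custom operator could be wired into the DAG from the question (the variable and task names come from the question; the job passed at construction time is only a placeholder, since execute() replaces it with the value pulled from XCom):

# Usage sketch, assuming the DAG, t1, t2, v_conn_id and v_project_id from the question.
spark_job_submit = CustomDataprocSubmitJobOperator(
    task_id="XXXX",
    job={},  # placeholder; overwritten in execute() with the XCom value
    region="us-central1",
    gcp_conn_id=v_conn_id,
    project_id=v_project_id,
)

t1 >> t2 >> spark_job_submit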

Upvotes: 1

Michal Volešíni

Reputation: 111

The problem is that you are trying to use {{ task_instance.xcom_pull() }} outside of a task instance. There the template is never rendered, so your job_args variable is just a string, and job_args["gcs_job"] won't work.
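
To illustrate (a plain-Python sketch, outside of Airflow): at DAG parse time the template has not been rendered yet, so the variable holds the raw Jinja string.

# What the DAG file actually sees at parse time:
job_args = "{{ task_instance.xcom_pull(task_ids='define_args', key='define_args') }}"

print(type(job_args))   # <class 'str'>
job_args["gcs_job"]     # TypeError: string indices must be integers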

You can either write a custom operator that inherits from the original operator and tweaks it to read the XCom value, as Mazlum Tosun proposed, or you can create another PythonOperator that transforms the value and pushes it to XCom, so that DataprocSubmitJobOperator can read it from there:

def f_generate_job_args(ti, **kwargs):
    project_id = kwargs['project_id']
    cluster_name = kwargs['cluster_name']

    # Pull the job args pushed by the define_args task.
    xcom_results = ti.xcom_pull(task_ids='define_args', key='define_args')

    # Build the full submit_job dict expected by DataprocSubmitJobOperator.
    submit_job = {
        "reference": {"project_id": project_id},
        "placement": {"cluster_name": cluster_name},
        "spark_job": xcom_results["gcs_job"]
    }

    return submit_job

dag = DAG(dag_id=dag_id, schedule_interval=None, default_args=default_args,render_template_as_native_obj=True)
with dag:        
    t1 = PubSubPullSensor(task_id='pull-messages',
                              project="projectname",
                              ack_messages=True,
                              max_messages=1,
                              subscription="subscribtionname")

    message = "{{ task_instance.xcom_pull() }}"
    t2 = PythonOperator(
            task_id='define_args',
            python_callable=create_args_from_event,
            op_kwargs={'var': message},
            provide_context=True,
    )
    

    generate_job_args = PythonOperator(
            task_id='generate_job_args',
            python_callable=f_generate_job_args,
            op_kwargs={'project_id': v_project_id, 'cluster_name': v_cluster_name},
            )            
    
    spark_job_submit = DataprocSubmitJobOperator(
            task_id="XXXX",
            job="{{ ti.xcom_pull(task_ids='generate_job_args') }}",
            location="us-central1",
            gcp_conn_id=v_conn_id,
            project_id=v_project_id
        )

    t1 >> t2 >> generate_job_args >> spark_job_submit

Another option is to create the whole submit job dict in your define_args task:

t2 = PythonOperator(
            task_id='define_args',
            python_callable=create_args_from_event,
            op_kwargs={'var': message, 'project_id': v_project_id, 'cluster_name': v_cluster_name},
            provide_context=True,
    )

and in your create_args_from_event either push a new key with xcom_push or prepare the whole dict, depending on your needs.
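
For illustration, a minimal sketch of what create_args_from_event could look like when it prepares the whole dict; how the Spark args are derived from the Pub/Sub message is an assumption about your payload, and the key name my_new_generated_key is just an example matching the snippet that follows:

def create_args_from_event(ti, var, project_id, cluster_name, **kwargs):
    # 'var' is the Pub/Sub message pulled by the sensor; deriving the Spark args
    # from it is left as an assumption about your payload.
    spark_job = {
        "args": ["--foo=bar", "--foo2=bar2"],
        "jar_file_uris": ["gs://...."],
        "main_class": "com.xyz.something",
    }

    submit_job = {
        "reference": {"project_id": project_id},
        "placement": {"cluster_name": cluster_name},
        "spark_job": spark_job,
    }

    # Push the whole dict under a dedicated key so the operator can template it.
    ti.xcom_push(key='my_new_generated_key', value=submit_job)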

then you can easily use

spark_job_submit = DataprocSubmitJobOperator(
            task_id="XXXX",
            job="{{ ti.xcom_pull(task_ids='define_args', key='my_new_generated_key') }}",
            location="us-central1",
            gcp_conn_id=v_conn_id,
            project_id=v_project_id
        ) 

Don't forget to set render_template_as_native_obj=True on your DAG.

Upvotes: 0
