andos

Reputation: 106

Pull XCom from a TaskFlow task into a sensor inside a task group class instance

I've created a task group class that wraps a common pattern of tasks I need to instantiate several times in my DAG. One task submits a job request to an API and returns the API response. I then need to extract the job id from that response and pass it to a sensor, which polls a separate API until the job has completed. However, I can't get the XCom pull syntax right.

I believe this is because I need to prefix the task_id I'm pulling from with the group_id, but the group_id is different in each instance of the class I create. Can anyone help me fix the syntax of my XCom pull so that it uses the group_id of the current task instance?
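To illustrate the prefixing, here is a minimal plain-Python sketch of how the full task ID comes out for each instance, assuming the task group's default behaviour of prefixing child task IDs with the group ID. `child_task_id` is a hypothetical helper written only for this illustration, not an Airflow function, and the connection names are made up.

```python
def child_task_id(group_id: str, task_id: str, prefix_group_id: bool = True) -> str:
    """Hypothetical helper: mimic how a task's full ID is formed inside a task group."""
    # With prefixing enabled, the full ID is "<group_id>.<task_id>".
    return f"{group_id}.{task_id}" if prefix_group_id else task_id


# Each instance of the group gets a different group_id, so the ID to pull from differs too.
for name in ("sales", "hr"):  # made-up connection names
    group_id = f"register_schemas{name}"
    print(child_task_id(group_id, "refresh_schema_connections"))
# register_schemassales.refresh_schema_connections
# register_schemashr.refresh_schema_connections
```

So the string passed to xcom_pull has to change per instance, which is why a single hard-coded task ID does not work here.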

from airflow.decorators import task
from airflow.utils.task_group import TaskGroup

# MyJobSensor and get_api_client are my own helpers, defined elsewhere (not shown).
class MyTaskGroup(TaskGroup):
    def __init__(self, database_connection):
        super().__init__(group_id=f"register_schemas{database_connection['name']}", tooltip="Registering Schemas")

        @task(task_group=self)
        def refresh_schema_connections(database_connection):
            client = get_api_client()
            return client.submit_job(database_connection)

        schema_job = refresh_schema_connections(database_connection['id'])

        wait_for_schema_job_completion = MyJobSensor(
            task_id="wait_for_schema_job_completion",
            task_group=self,
            mode="reschedule",
            # This is the xcom_pull expression I can't get right:
            job_id="{{ task_instance.xcom_pull(task_instance.task_group_id + '.refresh_schema_connections')['output']['id'] }}"
        )
        schema_job >> wait_for_schema_job_completion

I've tried a few variations, but I can't find the right attribute to get the group id.

Upvotes: 1

Views: 502

Answers (1)

andos

Reputation: 106

I couldn't get it to work this way. In the end I wrote a task that extracts the specific field I need from the response of the previous task, and passed its output to the sensor:

@task(task_group=self)
def extract_job_id(response):
    # Pull just the job id out of the API response returned by the previous task
    return response['id']

wait_for_schema_job_completion = MyJobSensor(
    task_id="wait_for_schema_job_completion",
    task_group=self,
    mode="reschedule",
    job_id=extract_job_id(refresh_schema_connections(get_first_conn(database_connections)))
)
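Another option that might work: the group ID is already known when the class is constructed, so the full task ID can be built in Python and embedded in the template string, rather than looked up inside Jinja. Below is a minimal sketch of that idea in plain Python. `xcom_pull_template` is a hypothetical helper, not part of Airflow. It assumes that `job_id` is a templated field of the sensor and that the response has a top-level `id` key, as in the code above. I have not tested it against this DAG.

```python
def xcom_pull_template(group_id: str, task_id: str, field: str) -> str:
    """Hypothetical helper: build a Jinja expression that pulls one field
    from the return value of a task living inside the given task group."""
    # Tasks inside a group are addressed as "<group_id>.<task_id>" (default prefixing).
    full_task_id = f"{group_id}.{task_id}"
    # Plain concatenation keeps the literal "{{" and "}}" easy to read.
    return "{{ ti.xcom_pull(task_ids='" + full_task_id + "')['" + field + "'] }}"


print(xcom_pull_template("register_schemassales", "refresh_schema_connections", "id"))
# {{ ti.xcom_pull(task_ids='register_schemassales.refresh_schema_connections')['id'] }}
```

Inside `__init__`, after the `super().__init__(...)` call, this could be used as `job_id=xcom_pull_template(self.group_id, "refresh_schema_connections", "id")`, which avoids the extra extraction task at the cost of a hand-built template string.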
     

Upvotes: 1
