Justinas Marozas

Reputation: 2692

Airflow: Run a task when some upstream is skipped by ShortCircuitOperator

I have a task that I'll call final that has multiple upstream connections. When one of the upstreams gets skipped by a ShortCircuitOperator, this task gets skipped as well. I don't want the final task to get skipped, as it has to report on DAG success.

To avoid it getting skipped I used trigger_rule='all_done', but it still gets skipped.

If I use BranchPythonOperator instead of ShortCircuitOperator, the final task doesn't get skipped. It would seem like a branching workflow could be a solution, even if not optimal, but now final won't respect failures of upstream tasks.

How do I get it to only run when upstreams are successful or skipped?

Sample ShortCircuit DAG:

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import ShortCircuitOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'shortcircuit_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

def shortcircuit_fn():
    return randint(0, 1) == 1

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')

work = DummyOperator(dag=dag, task_id='work')
short = ShortCircuitOperator(dag=dag, task_id='short_circuit', python_callable=shortcircuit_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> short >> work >> final
task_1 >> task_2 >> final

[screenshot: DAG with ShortCircuitOperator]

Sample Branch DAG:

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'branch_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

# these two exist only to keep the real tasks from getting skipped as direct dependencies of the branch operator
to_do_work = DummyOperator(dag=dag, task_id='to_do_work')
to_skip_work = DummyOperator(dag=dag, task_id='to_skip_work')

def branch_fn():
    return to_do_work.task_id if randint(0, 1) == 1 else to_skip_work.task_id

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')

work = DummyOperator(dag=dag, task_id='work')
branch = BranchPythonOperator(dag=dag, task_id='branch', python_callable=branch_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> branch >> to_do_work >> work >> final
branch >> to_skip_work >> final
task_1 >> task_2 >> final

[screenshot: DAG with BranchPythonOperator]

Upvotes: 19

Views: 29040

Answers (7)

Alexander Shapkin

Reputation: 1262

This is the classic BranchOperator pitfall. The solution is to use a trigger rule:

trigger_rule='none_failed_min_one_success'

The task runs if no direct upstream task has failed and at least one has succeeded.
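
For illustration, applied to the question's final task it might look like the sketch below (assuming a recent Airflow 2.x release, where this trigger rule is available; DummyOperator as in the question):

final = DummyOperator(
    dag=dag,
    task_id="final",
    # run once no direct upstream has failed and at least one has succeeded
    trigger_rule="none_failed_min_one_success")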

More info: Airflow Trigger Rules: All you need to know!

Upvotes: 0

nammerkage

Reputation: 304

The ShortCircuitOperator can now be configured to respect downstream trigger rules. The default behavior is to ignore them. You can make the operator respect them by setting ignore_downstream_trigger_rules=False:

task = ShortCircuitOperator(
    task_id='task_id',
    python_callable=function,
    ignore_downstream_trigger_rules=False,
)
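
With ignore_downstream_trigger_rules=False, the operator only skips its direct downstream tasks; anything further downstream is left to the scheduler, so a final task with a suitable trigger rule (e.g. none_failed) can still run.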

Upvotes: 2

raphaelauv

Reputation: 980

This question is still legit with Airflow 1.10.x.

The following solution works with Airflow 1.10.x; it is not tested yet with Airflow 2.x.

ShortCircuitOperator will skip all downstream tasks, whatever trigger_rule is set.

The solution of @michael-spector will only work for simple cases, not for this one:

[screenshot: example Airflow DAG]

With @michael-spector's solution the task L will not be skipped (only the tasks E, F, G and H will be skipped).

A solution is this (based on @michael-spector's proposition):

# assuming Airflow 1.10.x, where these imports hold
from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator


class ShortCircuitOperatorOnlyDirectDownStream(PythonOperator, SkipMixin):
    r"""
    Works like a ShortCircuitOperator, but it only skips the tasks that have
    this task in their upstream.

    So if a task has this task AND another task in its upstream, it will not
    be skipped:

            -> B -> C -> D ------\
          /                       \
    A -> K                         -> Y
     \                            /
       -> F -> G - P -----------/

    If K is a normal ShortCircuitOperator and the condition is False, then
    B, C, D and Y will be skipped.

    If K is a ShortCircuitOperatorOnlyDirectDownStream and the condition is
    False, then B, C and D will be skipped, but not Y.
    """

    def find_tasks_to_skip(self, task, found_tasks_to_skip=None, found_tasks_to_skip_names=None):
        if not found_tasks_to_skip:  # list of tasks to skip
            found_tasks_to_skip = []

        # necessary because found_tasks_to_skip holds task objects, not task_ids
        if not found_tasks_to_skip_names:
            found_tasks_to_skip_names = set()

        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            self.log.info("UPSTREAM: %s", t.upstream_task_ids)
            self.log.info(
                "Do the tasks already marked for skipping %s "
                "contain all the upstream tasks %s?",
                found_tasks_to_skip_names, t.upstream_task_ids)

            # if len == 1 the task is only preceded by a skipped task;
            # otherwise check whether ALL of its upstream tasks are skipped
            if len(t.upstream_task_ids) == 1 or all(elem in found_tasks_to_skip_names for elem in t.upstream_task_ids):
                found_tasks_to_skip.append(t)
                found_tasks_to_skip_names.add(t.task_id)
                self.find_tasks_to_skip(t, found_tasks_to_skip, found_tasks_to_skip_names)

        return found_tasks_to_skip

    def execute(self, context):
        condition = super(ShortCircuitOperatorOnlyDirectDownStream, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info('Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)

        self.log.info("Done.")

Upvotes: 3

Rob Knights

Reputation: 247

I'm posting another possible workaround for this, since it's a method that does not require a custom operator implementation.

I was influenced by the solution in this blog post, which uses a PythonOperator that raises an AirflowSkipException; this skips the task itself, and its downstream tasks are then skipped individually.

https://godatadriven.com/blog/the-zen-of-python-and-apache-airflow/

This then respects the trigger_rule of the final downstream task, which in my case I set to trigger_rule='none_failed'.

Modified example as per the blog to include a final task:

# assuming these imports in addition to the ones in the question's code
from airflow.exceptions import AirflowSkipException
from airflow.operators.python_operator import PythonOperator

def fn_short_circuit(**context):
    if <<<some condition>>>:
        raise AirflowSkipException(
            "Skip this task and individual downstream tasks "
            "while respecting trigger rules.")

task_1 = DummyOperator(task_id="task_1", dag=dag)
task_2 = DummyOperator(task_id="task_2", dag=dag)
work = DummyOperator(task_id="work", dag=dag)
# a plain PythonOperator (not a ShortCircuitOperator), so that the
# trigger rules of downstream tasks are respected when it skips itself
short = PythonOperator(
    task_id="short_circuit",
    python_callable=fn_short_circuit,
    provide_context=True,
    dag=dag)
final_task = DummyOperator(
    task_id="final_task",
    trigger_rule='none_failed',
    dag=dag)

task_1 >> short >> work >> final_task
task_1 >> task_2 >> final_task

Upvotes: 10

Justinas Marozas

Reputation: 2692

I've made it work by making the final task check the statuses of upstream task instances. It's not beautiful, as the only way I've found to access their state is by querying the Airflow DB.

# # additional imports to ones in question code
# from airflow import AirflowException
# from airflow.models import TaskInstance
# from airflow.operators.python_operator import PythonOperator
# from airflow.settings import Session
# from airflow.utils.state import State
# from airflow.utils.trigger_rule import TriggerRule

def all_upstreams_either_succeeded_or_skipped(dag, task, task_instance, **context):
    """
    Find directly upstream task instances and count how many are not in the
    preferred statuses. Return True if there are no instances with
    non-preferred statuses.
    """
    upstream_task_ids = [t.task_id for t in task.get_direct_relatives(upstream=True)]
    session = Session()
    query = (session
        .query(TaskInstance)
        .filter(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.execution_date.in_([task_instance.execution_date]),
            TaskInstance.task_id.in_(upstream_task_ids)
        )
    )
    upstream_task_instances = query.all()
    unhappy_task_instances = [ti for ti in upstream_task_instances if ti.state not in [State.SUCCESS, State.SKIPPED]]
    print(unhappy_task_instances)
    return len(unhappy_task_instances) == 0

def final_fn(**context):
    """
    fail if upstream task instances have unwanted statuses
    """
    if not all_upstreams_either_succeeded_or_skipped(**context):
        raise AirflowException("Not all upstream tasks succeeded.")
    # Do things

# will run when upstream task instances are done, including failed
final = PythonOperator(
    dag=dag,
    task_id="final",
    trigger_rule=TriggerRule.ALL_DONE,
    python_callable=final_fn,
    provide_context=True)
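
As a possible simplification (untested; it assumes DagRun.get_task_instances, which is available on the dag_run object passed in the context), the manual session query could be replaced like this:

def all_upstreams_either_succeeded_or_skipped(task, dag_run, **context):
    # the dag_run's task instances share our execution_date by definition,
    # so no date filter is needed; final_fn can call this with **context as before
    upstream_task_ids = {t.task_id for t in task.get_direct_relatives(upstream=True)}
    unhappy_task_instances = [
        ti for ti in dag_run.get_task_instances()
        if ti.task_id in upstream_task_ids
        and ti.state not in (State.SUCCESS, State.SKIPPED)]
    return len(unhappy_task_instances) == 0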

Upvotes: 1

Geoff

Reputation: 421

This may have been added after you asked your initial question, but Airflow now conveniently has a trigger_rule value of none_failed. If you set this on your final task, it should complete whether upstream tasks are skipped or succeeded, just not when they fail.

More info: https://airflow.apache.org/concepts.html#trigger-rules

Upvotes: -2

Michael Spector

Reputation: 37019

I've ended up developing a custom ShortCircuitOperator based on the original one:

# assuming Airflow 1.10.x, where these imports hold
from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator

class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Allows a workflow to continue only if a condition is met. Otherwise, the
    workflow "short-circuits" and downstream tasks that only rely on this operator
    are skipped.

    The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
    condition and short-circuits the workflow if the condition is False. Any
    downstream tasks that only rely on this operator are marked with a state of "skipped".
    If the condition is True, downstream tasks proceed as normal.

    The condition is determined by the result of `python_callable`.
    """

    def find_tasks_to_skip(self, task, found_tasks=None):
        if not found_tasks:
            found_tasks = []
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            if len(t.upstream_task_ids) == 1:
                found_tasks.append(t)
                self.find_tasks_to_skip(t, found_tasks)
        return found_tasks

    def execute(self, context):
        condition = super(ShortCircuitOperator, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info(
            'Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)

        self.log.info("Done.")

This operator makes sure that downstream tasks relying on multiple paths don't get skipped because of one skipped task.
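
Note the len(t.upstream_task_ids) == 1 check: a downstream task with more than one upstream is never skipped by this operator, even if all of its upstream tasks end up skipped. See @raphaelauv's answer above for a variant that handles that case.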

Upvotes: 16
