cereal killer

Reputation: 63

How to stop chain execution for failing task in the middle of a chain?

Question

How to stop chain execution for a task that fails somewhere in the middle of a chain?

Example

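from celery import Celery
from celery.exceptions import TimeoutError as CeleryTimeoutError

# Assumed setup: Redis as both broker and result backend (URLs are placeholders).
app = Celery("tasks", broker="redis://localhost:6379/0", backend="redis://localhost:6379/0")


@app.task
def ok(i):
    print(f"ok = {i}")
    return i + 1


@app.task
def fail(i):
    print(f"fail = {i}")
    raise RuntimeError(str(i))


if __name__ == "__main__":
    fails_in_the_middle = (ok.s(1) | fail.s() | ok.s())
    fails_in_the_end = (ok.s(1) | fail.s())

    for c in [fails_in_the_end, fails_in_the_middle]:
        resp = c.delay()
        try:
            resp.get(timeout=5)
        except RuntimeError:
            print("runtime error (as expected)")
        except CeleryTimeoutError:
            # QUESTION: How do I make this never happen for `fails_in_the_middle`?
            print("timed out (not expected)")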

Observations

It looks like AsyncResult.get() blocks forever if there are pending tasks in the chain. Cleaning up the chain by setting self.request.chain = None does not help either; the call still blocks until the timeout is reached.

The behavior above occurs when Redis is used as both the message broker and the result backend. When tasks are executed eagerly (within the current process), the behavior is as expected.
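For reproducing the eager case, eager execution can be switched on through configuration. This is a minimal sketch, assuming a local Redis instance (the app name and URLs are placeholders): the `task_always_eager` setting makes `delay()`/`apply_async()` run tasks in the calling process instead of sending them to a worker.

```python
from celery import Celery

# Placeholder app name and Redis URLs; adjust for your setup.
app = Celery("tasks", broker="redis://localhost:6379/0", backend="redis://localhost:6379/0")

# Execute tasks locally and synchronously instead of queuing them for a worker.
app.conf.task_always_eager = True
```
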

Would be grateful for any insights. Thank you!

Upvotes: 2

Views: 1656

Answers (1)

cereal killer

Reputation: 63

I got the answer from the celery-users group; hopefully it will help someone else. All credit goes to Ing. Josue Balandrano Coronel.

There's a caveat when getting the result the way you're doing it. A chain is a set of tasks linked together; when you call resp = c.delay(), you are queuing all the tasks in the chain.

The object that c.delay() returns is not a pointer to the entire chain but a pointer to the last task in the chain. In other words, resp points to the result of the last task in the chain, in this case the second ok.s(). So if the middle task fails and you call resp.get(), it will time out, because the task that resp points to never actually got executed.

This also explains the behavior you're seeing for fails_in_the_end: there, resp points to fail.s(), so when you call resp.get(), Celery re-raises the error that happened in that task. This is how Celery handles accessing the result of a task that raised an error without handling it.
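To make the explanation concrete, here is a toy, pure-Python model of what `c.delay()` hands back. All names (`ToyResult`, `run_toy_chain`, `toy_ok`, `toy_fail`) are hypothetical and this is not how Celery is implemented; it only mirrors the idea that you get the last task's handle, linked to earlier tasks through `.parent`. A failure in the middle leaves that last handle PENDING, while the failure itself sits one `.parent` link back.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyResult:
    """Hypothetical stand-in for a result handle (not Celery's AsyncResult)."""
    name: str
    parent: Optional["ToyResult"] = None
    state: str = "PENDING"
    value: object = None


def run_toy_chain(funcs: List[Callable], arg: object) -> ToyResult:
    """Run funcs in sequence like a chain; return only the LAST handle."""
    handles: List[ToyResult] = []
    parent: Optional[ToyResult] = None
    for func in funcs:
        # Each handle is linked to the previous task's handle via .parent.
        parent = ToyResult(func.__name__, parent)
        handles.append(parent)
    for func, handle in zip(funcs, handles):
        try:
            arg = func(arg)
        except Exception as exc:
            handle.state, handle.value = "FAILURE", exc
            break  # later tasks never run, so their handles stay PENDING
        handle.state, handle.value = "SUCCESS", arg
    return handles[-1]


def toy_ok(i: int) -> int:
    return i + 1


def toy_fail(i: int) -> int:
    raise RuntimeError(str(i))


middle = run_toy_chain([toy_ok, toy_fail, toy_ok], 1)
print(middle.state, middle.parent.state)  # PENDING FAILURE
end = run_toy_chain([toy_ok, toy_fail], 1)
print(end.state)  # FAILURE
```

Waiting on `middle` is the analogue of `resp.get()` on `fails_in_the_middle`: the handle you hold never leaves PENDING, so only a timeout can end the wait.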

Depending on what you actually want to do, there are a few options, listed in order of recommendation:

1) You can add success and error handlers to the chain to make sure you're not working with an incomplete chain.

@app.task(bind=True)
def success(self, result):
    # Runs only if every earlier task in the chain succeeded.
    print(f"Chain result: {result}")
    print(f"Chain: {self.request.chain}")  # remaining tasks after this one, if any

@app.task(bind=True)
def error(self, *args, **kwargs):
    # Error callback (errback), called when a task in the chain fails.
    print(f"args: {args}")
    print(f"kwargs: {kwargs}")

if __name__ == "__main__":
    fails_in_the_middle = (ok.s(1) | fail.s() | ok.s() | success.s()).on_error(error.s())
    fails_in_the_end = (ok.s(1) | fail.s() | success.s()).on_error(error.s())

2) Access the chain results as a graph using resp.collect(intermediate=True). collect() walks the result graph, a directed acyclic graph (DAG), and yields a (result, value) tuple for each node, so you can loop through the results in order:

resp = c.delay()
# collect() yields (AsyncResult, value) tuples, not bare results
for result, value in resp.collect(intermediate=True):
    print(value)

3) Walk the chain graph backwards to the first task and loop through its children to get the results. Since resp actually points to the last task in the chain, you can follow the parent links back to the first task and get the results from there:

resp = c.delay()

# resp is the last task; follow .parent links back to the first task
first = resp
while first.parent is not None:
    first = first.parent

# children of the first task (None until it has finished, hence `or []`)
for child in first.children or []:
    child.get()
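Building on option 3, here is a sketch (not from the original answer) that reports a mid-chain failure instead of timing out. It assumes only that the object returned by `c.delay()` exposes AsyncResult's `.parent`, `.failed()`, `.ready()` and `.get()`; the helper names `chain_nodes`, `first_failure` and `wait_for_chain` are hypothetical. It never blocks in `.get()` on a task that has not finished: it polls each link's state and re-raises the earliest failure.

```python
import time
from typing import Any, List, Optional


def chain_nodes(last: Any) -> List[Any]:
    """Return the chain's result handles ordered first -> last."""
    nodes = []
    node = last
    while node is not None:  # the first task's .parent is None
        nodes.append(node)
        node = node.parent
    nodes.reverse()
    return nodes


def first_failure(last: Any) -> Optional[Any]:
    """Return the earliest failed result in the chain, or None."""
    for node in chain_nodes(last):
        if node.failed():
            return node
    return None


def wait_for_chain(last: Any, timeout: float = 5.0, interval: float = 0.1) -> Any:
    """Poll until the last task is ready or an earlier task has failed."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        failed = first_failure(last)
        if failed is not None:
            # For a failed AsyncResult, get() re-raises the task's exception.
            return failed.get()
        if last.ready():
            return last.get()
        time.sleep(interval)
    raise TimeoutError(f"chain not finished after {timeout} seconds")
```

In the question's loop, `wait_for_chain(resp, timeout=5)` would take the place of `resp.get(timeout=5)`, so `fails_in_the_middle` should raise the RuntimeError rather than time out, assuming the worker records the failure in the Redis backend. Each poll queries the backend once per link, so `interval` trades latency against backend load.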

Upvotes: 2
