Reputation: 63
Question
How to stop chain execution for a task that fails somewhere in the middle of a chain?
Example
import celery.exceptions
from celery import Celery

# Broker/backend URLs assumed; Redis is used for both (see Observations below).
app = Celery("demo", broker="redis://localhost", backend="redis://localhost")

@app.task
def ok(i):
    print(f"ok = {i}")
    return i + 1

@app.task
def fail(i):
    print(f"fail = {i}")
    raise RuntimeError(str(i))

if __name__ == "__main__":
    fails_in_the_middle = (ok.s(1) | fail.s() | ok.s())
    fails_in_the_end = (ok.s(1) | fail.s())
    for c in [fails_in_the_end, fails_in_the_middle]:
        resp = c.delay()
        try:
            resp.get(timeout=5)
        except RuntimeError:
            print("runtime error (as expected)")
        except celery.exceptions.TimeoutError:
            # QUESTION: How do I make this never happen for `fails_in_the_middle`?
            print("timed out (not expected)")
Observations
- fails_in_the_middle chain: resp.get(timeout=5) always results in TimeoutError.
- fails_in_the_end chain: always raises RuntimeError(2) immediately, without waiting for the timeout.

It looks like AsyncResult.get() blocks forever if there are pending tasks in the chain, but cleaning up the chain by setting self.request.chain = None does not help; the call is still blocked until the timeout is reached.

The behavior above occurs when Redis is used both as the message broker and the result backend. When tasks are marked as eager (executed within the current process), the behavior is as expected.
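For reference, the eager mode mentioned above can be enabled like this (a minimal sketch: task_always_eager and task_eager_propagates are real Celery settings, while the "demo" app name is a placeholder assumption):

```python
from celery import Celery

app = Celery("demo")                   # "demo" is a placeholder app name
app.conf.task_always_eager = True      # execute tasks in the calling process
app.conf.task_eager_propagates = True  # re-raise task exceptions from .get()
```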
Would be grateful for any insights. Thank you!
Upvotes: 2
Views: 1656
Reputation: 63
Got the answer from celery-users group, hopefully it will help someone else. All the credits go to Ing. Josue Balandrano Coronel.
There's a caveat when getting the result the way you're doing it. A chain is a bunch of tasks linked together; when you do
resp = c.delay()
you are queuing all the tasks in the chain. The object that
c.delay()
returns is not a pointer to the entire chain but a pointer to the last task in the chain. Meaning, resp
ends up with a pointer to the result of the last task in the chain, in this case the second ok.s()
. This means that if the middle task fails and you try to do resp.get()
, it will time out because the task that resp
is pointing to never actually got executed. This also explains the behavior you're seeing for
fails_in_the_end
: because resp
is a pointer to fail.s()
, when you do resp.get()
Celery will raise the error that happened in the task. This is how Celery handles accessing results of tasks which raised errors when there's no error handling in the task. Now, depending on what you actually want to do, there are a few different options. In order of recommendation:
1) You can add success and error handlers to the chain to make sure you're not working with an incomplete chain.
@app.task(bind=True)
def success(self, result):
    print(f"Chain result: {result}")
    print(f"Chain: {self.request.chain}")

@app.task(bind=True)
def error(self, *args, **kwargs):
    print(f"args: {args}")
    print(f"kwargs: {kwargs}")

if __name__ == "__main__":
    fails_in_the_middle = (ok.s(1) | fail.s() | ok.s() | success.s()).on_error(error.s())
    fails_in_the_end = (ok.s(1) | fail.s() | success.s()).on_error(error.s())
2) Access chain results as a graph using
resp.collect(intermediate=True)
collect()
returns all the results in the chain as a directed acyclic graph (DAG). It's usually easier to iterate it so you can loop through each of the results in order; note that collect() yields (result, value) tuples:
resp = c.delay()
for result, value in resp.collect(intermediate=True):
    print(value)
3) Walk the chain graph backwards until you get to the first task and loop through its children to get the results. Since what you end up with in
resp
is actually the last task in the chain, you can walk the graph backwards until the first task in the chain and get the results from there:
resp = c.delay()
parent = resp
while parent.parent is not None:
    parent = parent.parent       # walk back to the first task in the chain
for child in parent.children or []:
    child.get()
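The backwards walk can be sketched stdlib-only as well (again with a hypothetical FakeResult standing in for AsyncResult), collecting the task names back into chain order:

```python
# Hypothetical stand-in for AsyncResult: only the .parent link matters here.
class FakeResult:
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent

# resp points at the last task; .parent links lead back to the first
resp = FakeResult("task3", FakeResult("task2", FakeResult("task1")))

node, ordered = resp, []
while node is not None:        # walk backwards via .parent
    ordered.append(node.name)
    node = node.parent
ordered.reverse()              # restore original chain order

print(ordered)  # ['task1', 'task2', 'task3']
```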
Upvotes: 2