Reputation: 154504
Is there a standard method for debouncing Celery tasks?
For example, so that a task can be "started" multiple times, but will only be run once after some delay:
def debounce_task(task):
    if task_is_queued(task):
        return
    task.apply_async(countdown=30)
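One way the task_is_queued check could work is with an expiring Redis flag (a sketch only, assuming a Redis client is available; the key scheme is made up):
import redis

r = redis.Redis()

def task_is_queued(task, countdown=30):
    # SET NX succeeds only for the first caller in the window; the key
    # expires on its own, so a lost task can't wedge the flag forever.
    return not r.set('debounce:%s' % task.name, 1, nx=True, ex=countdown)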
Upvotes: 18
Views: 5049
Reputation: 2806
Here's a more filled-out solution based on https://stackoverflow.com/a/28157498/4391298, but turned into a decorator and reaching into the Kombu connection pool to reuse your Redis counter.
import logging
from functools import wraps

# Not strictly required
from django.core.exceptions import ImproperlyConfigured
from django.core.cache.utils import make_template_fragment_key

from celery.utils import gen_task_name

LOGGER = logging.getLogger(__name__)

# NOTE: ``app`` below is assumed to be your Celery application instance.


def debounced_task(**options):
    """Debounced task decorator."""

    try:
        countdown = options.pop('countdown')
    except KeyError:
        raise ImproperlyConfigured("Debounced tasks require a countdown")

    def factory(func):
        """Decorator factory."""
        try:
            name = options.pop('name')
        except KeyError:
            name = gen_task_name(app, func.__name__, func.__module__)

        @wraps(func)
        def inner(*args, **kwargs):
            """Decorated function."""
            key = make_template_fragment_key(name, [args, kwargs])
            with app.pool.acquire_channel(block=True) as (_, channel):
                depth = channel.client.decr(key)

                if depth <= 0:
                    try:
                        func(*args, **kwargs)
                    except Exception:
                        # The task failed (or is going to retry); set the
                        # count back to where it was.
                        channel.client.set(key, depth)
                        raise
                else:
                    LOGGER.debug("%s calls pending to %s", depth, name)

        task = app._task_from_fun(inner, **options, name=name + '__debounced')

        @wraps(func)
        def debouncer(*args, **kwargs):
            """Debouncer that calls the real task.

            This is the task we are scheduling."""
            key = make_template_fragment_key(name, [args, kwargs])
            with app.pool.acquire_channel(block=True) as (_, channel):
                # Mark this key to expire after the countdown: if our task
                # never runs, or runs too many times, we want Redis to be
                # cleaned up so the issue eventually resolves itself.
                channel.client.expire(key, countdown + 10)

                depth = channel.client.incr(key)

                LOGGER.debug("Requesting %s in %i seconds (depth=%s)",
                             name, countdown, depth)
                task.si(*args, **kwargs).apply_async(countdown=countdown)

        return app._task_from_fun(debouncer, **options, name=name)

    return factory
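Usage would then look something like this (a sketch: the task name and body here are made up, not part of the code above):
@debounced_task(countdown=30)
def rebuild_search_index(tenant_id):
    # Runs once per burst; intermediate calls only bump the Redis
    # counter and reschedule. (do_expensive_rebuild is a stand-in
    # for your real work.)
    do_expensive_rebuild(tenant_id)

# Callers invoke it like any other task; repeated calls collapse.
rebuild_search_index.delay(42)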
Upvotes: 0
Reputation: 154504
Here's the solution I came up with: https://gist.github.com/wolever/3cf2305613052f3810a271e09d42e35c
And copied here, for posterity:
import time

import redis


def get_redis_connection():
    return redis.Redis()


class TaskDebouncer(object):
    """ A simple Celery task debouncer.

        Usage::

            def debounce_process_corpus(corpus):
                # Only one task with ``key`` will be allowed to execute at a
                # time. For example, if the task was resizing an image, the key
                # might be the image's URL.
                key = "process_corpus:%s" % (corpus.id,)
                TaskDebouncer.delay(
                    key, process_corpus, args=[corpus.id], countdown=0,
                )

            @task(bind=True)
            def process_corpus(self, corpus_id, debounce_key=None):
                debounce = TaskDebouncer(debounce_key, keepalive=30)

                corpus = Corpus.load(corpus_id)

                try:
                    for item in corpus:
                        item.process()

                        # If ``debounce.keepalive()`` isn't called every
                        # ``keepalive`` interval (the ``keepalive=30`` in the
                        # call to ``TaskDebouncer(...)``) the task will be
                        # considered dead and another one will be allowed to
                        # start.
                        debounce.keepalive()

                finally:
                    # ``finalize()`` will mark the task as complete and allow
                    # subsequent tasks to execute. If it returns true, there
                    # was another attempt to start a task with the same key
                    # while this task was running. Depending on your business
                    # logic, this might indicate that the task should be
                    # retried.
                    needs_retry = debounce.finalize()

                    if needs_retry:
                        raise self.retry(max_retries=None)
    """

    def __init__(self, key, keepalive=60):
        if key:
            self.key = key.partition("!")[0]
            self.run_key = key
        else:
            self.key = None
            self.run_key = None
        self._keepalive = keepalive
        self.cxn = get_redis_connection()
        self.init()
        self.keepalive()

    @classmethod
    def delay(cls, key, task, args=None, kwargs=None, countdown=30):
        cxn = get_redis_connection()
        now = int(time.time())
        first = cxn.set(key, now, nx=True, ex=countdown + 10)
        if not first:
            now = cxn.get(key)
        run_key = "%s!%s" % (key, now)
        if first:
            kwargs = dict(kwargs or {})
            kwargs["debounce_key"] = run_key
            task.apply_async(args=args, kwargs=kwargs, countdown=countdown)
        return (first, run_key)

    def init(self):
        self.initial = self.key and self.cxn.get(self.key)

    def keepalive(self, expire=None):
        if self.key is None:
            return
        expire = expire if expire is not None else self._keepalive
        self.cxn.expire(self.key, expire)

    def is_out_of_date(self):
        if self.key is None:
            return False
        return self.cxn.get(self.key) != self.initial

    def finalize(self):
        if self.key is None:
            return False
        with self.cxn.pipeline() as pipe:
            while True:
                try:
                    pipe.watch(self.key)
                    if pipe.get(self.key) != self.initial:
                        return True
                    pipe.multi()
                    pipe.delete(self.key)
                    pipe.execute()
                    break
                except redis.WatchError:
                    continue
        return False
Upvotes: 1
Reputation: 5834
bartek has the right idea: use Redis counters, which are atomic (and should be easily available if your broker is Redis). Although his solution is throttling, not debouncing, the difference is minor (getset vs. decr).
Queue up the task:
conn = get_redis()
conn.incr(key)
task.apply_async(args=args, kwargs=kwargs, countdown=countdown)
Then in the task:
conn = get_redis()
counter = conn.decr(key)
if counter > 0:
    # task is still queued
    return
# continue on to rest of task
It's hard to make this a decorator, since you need to decorate both the task and the call to the task itself, so you'd need one decorator that runs before Celery's @task decorator and one that runs after it. For now I've just made some functions that help me call the task, plus one that checks at the start of the task; a sketch follows.
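A sketch of what those helper functions might look like (the names here are assumptions, not the answerer's actual code; get_redis is the same connection helper used above):
def delay_debounced(task, key, args=None, kwargs=None, countdown=30):
    # Call this instead of task.delay(): bump the counter, then schedule.
    get_redis().incr(key)
    task.apply_async(args=args, kwargs=kwargs, countdown=countdown)

def should_run(key):
    # Call this at the top of the task body: only the copy that brings
    # the counter back to zero actually does the work.
    return get_redis().decr(key) <= 0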
Upvotes: 4
Reputation: 8865
Here's how you can do it with Mongo.
NOTE: I had to make the design a little more forgiving, as Celery tasks aren't guaranteed to execute the exact moment the eta is met or the countdown runs out. Also, Mongo expiring indexes are only cleaned up every minute or so, so you can't base the design around records being deleted the moment the eta is up.
Anyhow, the flow is something like this:

1. my_task is called: preflight increments a call counter and returns it as flight_id.
2. _my_task is set to be executed after TTL seconds.
3. When _my_task runs, it checks whether its flight_id is still current. If it's not, it aborts.

@celery.task(track_started=False, ignore_result=True)
def my_task(my_arg):
    flight_id = preflight(inflight_collection, 'my_task', HASH(my_arg), TTL)
    _my_task.apply_async((my_arg,), {'flight_id': flight_id}, countdown=TTL)
@celery.task(track_started=False, ignore_result=True)
def _my_task(my_arg, flight_id=None):
    if not check_for_takeoff(inflight_collection, 'my_task', HASH(my_arg), flight_id):
        return
    # ... actual work ... #
Library code:
import datetime

import pymongo

TTL = 5 * 60      # Run tasks after 5 minutes
EXPIRY = 6 * TTL  # This needs to be much larger than TTL.

# We need to store a list of task-executions currently pending.
# ``db`` is assumed to be your pymongo database handle.
inflight_collection = db['celery_In_Flight']
inflight_collection.create_index([('fn', pymongo.ASCENDING,),
                                  ('key', pymongo.ASCENDING,)])
inflight_collection.create_index('eta', expireAfterSeconds=EXPIRY)


def preflight(collection, fn, key, ttl):
    eta = datetime.datetime.now() + datetime.timedelta(seconds=ttl)
    result = collection.find_one_and_update({
        'fn': fn,
        'key': key,
    }, {
        '$set': {
            'eta': eta
        },
        '$inc': {
            'flightId': 1
        }
    }, upsert=True, return_document=pymongo.ReturnDocument.AFTER)
    print('Preflight[{}][{}] = {}'.format(fn, key, result['flightId']))
    return result['flightId']


def check_for_takeoff(collection, fn, key, flight_id):
    result = collection.find_one({
        'fn': fn,
        'key': key
    })
    ready = result is None or result['flightId'] == flight_id
    # Guard the log line: ``result`` can be None if the record expired.
    print('Check[{}][{}] = {}, {}'.format(
        fn, key, result and result['flightId'], ready))
    return ready
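HASH isn't defined above; a minimal stand-in that derives a stable key from the task argument might be (an assumption, not part of the original answer):
import hashlib

def HASH(value):
    # Stable digest of the argument, so equal calls share one record.
    return hashlib.sha1(repr(value).encode('utf-8')).hexdigest()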
Upvotes: 1
Reputation: 15609
Here's how we do it with Redis counters. All of this can probably be generalized in a decorator, but we only use it for a specific task (webhooks).
Your public-facing task is what you call from other functions. It'll need to increment a key in Redis. The key is formed from the arguments of your function, whatever they may be (this ensures the counter is unique amongst individual tasks).
@task
def your_public_task(*args, **kwargs):
    cache_key = make_public_task_cache_key(*args, **kwargs)
    get_redis().incr(cache_key)
    _your_task.apply_async(args=args, kwargs=kwargs, countdown=settings.QUEUE_DELAY)
Note that the cache key functions are shared (you want the same cache key in each function), and note the countdown setting.
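The cache key function itself isn't shown; a minimal version might hash the call signature (a sketch, not the answerer's actual code):
import hashlib

def make_public_task_cache_key(*args, **kwargs):
    # Same arguments -> same key, so each distinct call signature gets
    # its own debounce counter.
    payload = repr((args, sorted(kwargs.items())))
    return 'your_public_task:%s' % hashlib.sha1(payload.encode('utf-8')).hexdigest()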
Then, the actual task executing the code does the following:
@task
def _your_task(*args, **kwargs):
    cache_key = make_public_task_cache_key(*args, **kwargs)
    counter = get_redis().getset(cache_key, 0)
    # redis makes the zero a string.
    if counter == '0':
        return
    # ... execute your actual task code ...
This lets you hit your_public_task.delay(..) as many times as you want within your QUEUE_DELAY, and it'll only fire off once.
Upvotes: 14