David Wolever

Reputation: 154504

Debounce Celery tasks?

Is there a standard method for debouncing Celery tasks?

For example, so that a task can be "started" multiple times, but will only be run once after some delay:

def debounce_task(task):
    if task_is_queued(task):
        return
    task.apply_async(countdown=30)

Upvotes: 18

Views: 5049

Answers (5)

Danielle Madeley

Reputation: 2806

Here's a more fleshed-out solution based on https://stackoverflow.com/a/28157498/4391298, but turned into a decorator and reaching into the Kombu connection pool to reuse your Redis counter.

import logging
from functools import wraps

# Not strictly required
from django.core.exceptions import ImproperlyConfigured
from django.core.cache.utils import make_template_fragment_key

from celery.utils import gen_task_name

# ``app`` below is assumed to be your Celery application instance
# (e.g. ``from myproject.celery import app``).

LOGGER = logging.getLogger(__name__)


def debounced_task(**options):
    """Debounced task decorator."""

    try:
        countdown = options.pop('countdown')
    except KeyError:
        raise ImproperlyConfigured("Debounced tasks require a countdown")

    def factory(func):
        """Decorator factory."""
        try:
            name = options.pop('name')
        except KeyError:
            name = gen_task_name(app, func.__name__, func.__module__)

        @wraps(func)
        def inner(*args, **kwargs):
            """Decorated function."""

            key = make_template_fragment_key(name, [args, kwargs])
            with app.pool.acquire_channel(block=True) as (_, channel):
                depth = channel.client.decr(key)

                if depth <= 0:
                    try:
                        func(*args, **kwargs)
                    except:
                        # The task failed (or is going to retry), set the
                        # count back to where it was
                        channel.client.set(key, depth)
                        raise
                else:
                    LOGGER.debug("%s calls pending to %s",
                                depth, name)

        task = app._task_from_fun(inner, **options, name=name + '__debounced')

        @wraps(func)
        def debouncer(*args, **kwargs):
            """
            Debouncer that calls the real task.
            This is the task we are scheduling."""

            key = make_template_fragment_key(name, [args, kwargs])
            with app.pool.acquire_channel(block=True) as (_, channel):
                # Mark this key to expire after the countdown, in case our
                # task never runs or runs too many times, we want to clean
                # up our Redis to eventually resolve the issue.
                channel.client.expire(key, countdown + 10)
                depth = channel.client.incr(key)

            LOGGER.debug("Requesting %s in %i seconds (depth=%s)",
                         name, countdown, depth)
            task.si(*args, **kwargs).apply_async(countdown=countdown)

        return app._task_from_fun(debouncer, **options, name=name)

    return factory
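
For illustration, a minimal sketch of how the decorator might be used (the sync_user task and the 30-second countdown are assumptions, not part of the answer):

# Hypothetical usage: repeated calls within the countdown window collapse
# into a single execution of the decorated function.
@debounced_task(countdown=30)
def sync_user(user_id):
    ...  # the expensive work that should only run once per burst

# Each call bumps the Redis counter and schedules the real task; only the
# last one queued within the window actually does the work.
sync_user.delay(42)
sync_user.delay(42)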

Upvotes: 0

David Wolever

Reputation: 154504

Here's the solution I came up with: https://gist.github.com/wolever/3cf2305613052f3810a271e09d42e35c

And copied here, for posterity:

import time

import redis


def get_redis_connection():
    # redis-py has no ``redis.connect()``; return a client configured for
    # your deployment here.
    return redis.Redis()

class TaskDebouncer(object):
    """ A simple Celery task debouncer.

        Usage::

            def debounce_process_corpus(corpus):
                # Only one task with ``key`` will be allowed to execute at a
                # time. For example, if the task was resizing an image, the key
                # might be the image's URL.
                key = "process_corpus:%s" %(corpus.id, )
                TaskDebouncer.delay(
                    key, process_corpus, args=[corpus.id], countdown=0,
                )

            @task(bind=True)
            def process_corpus(self, corpus_id, debounce_key=None):
                debounce = TaskDebouncer(debounce_key, keepalive=30)

                corpus = Corpus.load(corpus_id)

                try:
                    for item in corpus:
                        item.process()

                        # If ``debounce.keepalive()`` isn't called every
                        # ``keepalive`` interval (the ``keepalive=30`` in the
                        # call to ``TaskDebouncer(...)``) the task will be
                        # considered dead and another one will be allowed to
                        # start.
                        debounce.keepalive()

                finally:
                    # ``finalize()`` will mark the task as complete and allow
                    # subsequent tasks to execute. If it returns true, there
                    # was another attempt to start a task with the same key
                    # while this task was running. Depending on your business
                    # logic, this might indicate that the task should be
                    # retried.
                    needs_retry = debounce.finalize()

                if needs_retry:
                    raise self.retry(max_retries=None)

    """

    def __init__(self, key, keepalive=60):
        if key:
            self.key = key.partition("!")[0]
            self.run_key = key
        else:
            self.key = None
            self.run_key = None
        self._keepalive = keepalive
        self.cxn = get_redis_connection()
        self.init()
        self.keepalive()

    @classmethod
    def delay(cls, key, task, args=None, kwargs=None, countdown=30):
        cxn = get_redis_connection()
        now = int(time.time())
        first = cxn.set(key, now, nx=True, ex=countdown + 10)
        if not first:
            now = cxn.get(key)

        run_key = "%s!%s" %(key, now)
        if first:
            kwargs = dict(kwargs or {})
            kwargs["debounce_key"] = run_key
            task.apply_async(args=args, kwargs=kwargs, countdown=countdown)

        return (first, run_key)

    def init(self):
        self.initial = self.key and self.cxn.get(self.key)

    def keepalive(self, expire=None):
        if self.key is None:
            return
        expire = expire if expire is not None else self._keepalive
        self.cxn.expire(self.key, expire)

    def is_out_of_date(self):
        if self.key is None:
            return False
        return self.cxn.get(self.key) != self.initial

    def finalize(self):
        if self.key is None:
            return False
        with self.cxn.pipeline() as pipe:
            while True:
                try:
                    pipe.watch(self.key)
                    if pipe.get(self.key) != self.initial:
                        return True
                    pipe.multi()
                    pipe.delete(self.key)
                    pipe.execute()
                    break
                except redis.WatchError:
                    continue
        return False

Upvotes: 1

dalore

Reputation: 5834

Bartek has the right idea: use Redis counters, which are atomic (and should be readily available if your broker is Redis). Note that his solution is throttling, not debouncing, though the difference is minor (getset vs. decr).

Queue up the task:

conn = get_redis()
conn.incr(key)
task.apply_async(args=args, kwargs=kwargs, countdown=countdown)

Then in the task:

conn = get_redis()
counter = conn.decr(key)
if counter > 0:
    # another copy of this task is still queued; let the last one do the work
    return
# continue on to the rest of the task

It's hard to make this a decorator, since you need to wrap both the task itself and the call to the task. So you would need a decorator before the Celery @task decorator and one after it.

For now I've just made some functions that help me call the task, and one that checks at the start of the task.
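
A rough sketch of what those helpers could look like, assuming the answer's get_redis() factory and deriving the key from the task name plus its arguments (the helper names here are made up):

import json

QUEUE_DELAY = 30  # seconds to wait before the task actually runs

def make_debounce_key(task, args, kwargs):
    # Hypothetical helper: derive a stable Redis key from the task and its arguments.
    return "debounce:%s:%s" % (
        task.name, json.dumps([args, kwargs], sort_keys=True, default=str))

def debounced_delay(task, args=None, kwargs=None, countdown=QUEUE_DELAY):
    # Call this instead of task.delay(): bump the counter, then schedule the task.
    key = make_debounce_key(task, args or [], kwargs or {})
    get_redis().incr(key)
    task.apply_async(args=args, kwargs=kwargs, countdown=countdown)

def should_run(task, args, kwargs):
    # Call this at the top of the task body: only the last queued copy proceeds.
    key = make_debounce_key(task, args, kwargs)
    return get_redis().decr(key) <= 0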

Upvotes: 4

nathan-m

Reputation: 8865

Here's how you can do it with Mongo.

NOTE: I had to make the design a little more forgiving, as Celery tasks aren't guaranteed to execute at the exact moment the eta is met or the countdown runs out.

Also, Mongo's expiring indexes are only cleaned up every minute or so, so you can't base the design around records being deleted the moment the eta is up.

Anyhow, the flow is something like this:

  1. Client code calls my_task.
  2. preflight increments a call counter, and returns it as flight_id
  3. _my_task is set to be executed after TTL seconds.
  4. When _my_task runs, it checks whether its flight_id is still current. If it's not, it aborts.
  5. ... sometime later... mongo cleans up stale entries in the collection, via an expiring index.

@celery.task(track_started=False, ignore_result=True)
def my_task(my_arg):
    flight_id = preflight(inflight_collection, 'my_task', HASH(my_arg), TTL)
    _my_task.apply_async((my_arg,), {'flight_id':flight_id}, countdown=TTL)

@celery.task(track_started=False, ignore_result=True)
def _my_task(my_arg, flight_id=None):
    if not check_for_takeoff(inflight_collection, 'my_task', HASH(my_arg), flight_id):
        return
    # ... actual work ... #

Library code:

TTL = 5 * 60     # Run tasks after 5 minutes
EXPIRY = 6 * TTL # This needs to be much larger than TTL. 

# We need to store a list of task-executions currently pending
inflight_collection = db['celery_In_Flight']
inflight_collection.create_index([('fn', pymongo.ASCENDING,),
                                  ('key', pymongo.ASCENDING,)])
inflight_collection.create_index('eta', expiresAfterSeconds=EXPIRY)


def preflight(collection, fn, key, ttl):
    eta = datetime.datetime.now() + datetime.timedelta(seconds=ttl)
    result = collection.find_one_and_update({
        'fn': fn,
        'key': key,
    }, {
        '$set': {
            'eta': eta
        },
        '$inc': {
            'flightId': 1
        }
    }, upsert=True, return_document=pymongo.ReturnDocument.AFTER)
    print('Preflight[{}][{}] = {}'.format(fn, key, result['flightId']))
    return result['flightId']


def check_for_takeoff(collection, fn, key, flight_id):
    result = collection.find_one({
        'fn': fn,
        'key': key
    })
    ready = result is None or result['flightId'] == flight_id
    print('Check[{}][{}] = {}, {}'.format(
        fn, key, result['flightId'] if result else None, ready))
    return ready
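
HASH isn't defined in the answer; a minimal sketch of one possibility, hashing the argument's repr (purely an assumption):

import hashlib

def HASH(value):
    # One option: a stable digest of the argument's repr, so equal arguments
    # map to the same in-flight record.
    return hashlib.sha1(repr(value).encode('utf-8')).hexdigest()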

Upvotes: 1

Bartek

Reputation: 15609

Here's how we do it with Redis counters. All of this could probably be generalized into a decorator, but we only use it for a specific task (webhooks).

Your public-facing task is what you call from other functions. It needs to increment a key in Redis. The key is formed from the arguments of your function, whatever they may be (this ensures the counter is unique amongst individual tasks).

@task
def your_public_task(*args, **kwargs):
    cache_key = make_public_task_cache_key(*args, **kwargs)
    get_redis().incr(cache_key)
    _your_task.apply_async(args=args, kwargs=kwargs, countdown=settings.QUEUE_DELAY)

Note that the cache key function is shared (you want the same cache key in both tasks), and note the countdown setting.

Then, the actual task executing the code does the following:

@task
def _your_task(*args, **kwargs):
    cache_key = make_public_task_cache_key(*args, **kwargs)
    counter = get_redis().getset(cache_key, 0)
    # redis returns the stored value as a string ('0'), or bytes (b'0') on Python 3.
    if counter in ('0', b'0'):
        return

    ... execute your actual task code.

This lets you hit your_public_task.delay(..) as many times as you want, within your QUEUE_DELAY, and it'll only fire off once.
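
make_public_task_cache_key isn't shown above; a minimal sketch of one way it could be written, assuming the task arguments are JSON-serializable (an assumption on my part):

import hashlib
import json

def make_public_task_cache_key(*args, **kwargs):
    # Serialize the arguments deterministically and hash them, so every
    # distinct argument combination gets its own counter key.
    payload = json.dumps([args, kwargs], sort_keys=True, default=str)
    return "your_public_task:%s" % hashlib.md5(payload.encode('utf-8')).hexdigest()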

Upvotes: 14
