jupiterbjy

Reputation: 3530

Callback when tasks are added/removed in trio.Nursery

Question

Is there an official or better way to get a callback when tasks are added to or removed from a nursery, rather than wrapping the trio._core._run.GLOBAL_RUN_CONTEXT.runner.tasks set?


Details

To study the internals of trio.Nursery (and just for fun), I'm trying to get the task count of one specific nursery every time a task is added or removed.

What I found is that trio keeps every task in a single set: the tasks attribute of trio._core._run.GLOBAL_RUN_CONTEXT.runner.

Replacing that set with a new one caused a handler error on exit, so I wrapped the existing set instead.

Demonstration code:

from collections.abc import MutableSet
from typing import Iterator
import random
import time

import trio


class SetWithCallback(MutableSet):
    """
    Class to wrap around existing set for adding callback support
    """
    def __init__(self, original_set: set, add_callback=None, remove_callback=None):
        self.inner_set = original_set

        self.add_callback = add_callback if add_callback else lambda _: None
        self.remove_callback = remove_callback if remove_callback else lambda _: None

    def add(self, value) -> None:
        self.inner_set.add(value)
        self.add_callback(value)

    def remove(self, value) -> None:
        self.inner_set.remove(value)
        self.remove_callback(value)

    def discard(self, value) -> None:
        self.inner_set.discard(value)

    def __contains__(self, x: object) -> bool:
        return x in self.inner_set

    def __len__(self) -> int:
        return len(self.inner_set)

    def __iter__(self) -> Iterator:
        return iter(self.inner_set)


async def dummy_task(task_no, lifetime):
    """
    A very meaningful and serious workload

    Args:
        task_no: Task's ID
        lifetime: Lifetime of task
    """
    print(f"  Task {task_no} started, expecting lifetime of {lifetime}s!")
    start = time.time()
    await trio.sleep(lifetime)
    print(f"  Task {task_no} finished, actual lifetime was {time.time() - start:.6}s!")


async def main():
    # Wrap original tasks set with our new class
    # noinspection PyProtectedMember
    runner = trio._core._run.GLOBAL_RUN_CONTEXT.runner
    # noinspection PyTypeChecker
    runner.tasks = SetWithCallback(runner.tasks)

    async with trio.open_nursery() as nursery:

        # Callbacks run on every task add/remove in the whole run,
        # not only for tasks that belong to this nursery.
        def add_callback(task):
            # do something
            # child task count + 1, because the task is not in nursery.child_tasks yet.
            print(f"Task {id(task)} added, {len(nursery.child_tasks) + 1} in nursery {id(nursery)}")

        def remove_callback(task):
            # do something
            # child task count - 1, because the task is not removed from nursery.child_tasks yet.
            print(f"Task {id(task)} done, {len(nursery.child_tasks) - 1} in nursery {id(nursery)}")

        # Replace the no-op default callbacks with ones that report this nursery's task count.
        runner.tasks.add_callback = add_callback
        runner.tasks.remove_callback = remove_callback

        # spawn tasks
        for n in range(5):
            nursery.start_soon(dummy_task, n, random.randint(2, 5))
            await trio.sleep(1)


if __name__ == '__main__':
    trio.run(main)

Output:

Task 2436520969376 added, 1 in nursery 2436520962768
  Task 0 started, expecting lifetime of 3s!
Task 2436520969536 added, 2 in nursery 2436520962768
  Task 1 started, expecting lifetime of 2s!
Task 2436520969856 added, 3 in nursery 2436520962768
  Task 2 started, expecting lifetime of 2s!
  Task 0 finished, actual lifetime was 3.00393s!
Task 2436520969376 done, 2 in nursery 2436520962768
  Task 1 finished, actual lifetime was 2.01102s!
Task 2436520969536 done, 1 in nursery 2436520962768
Task 2436520969536 added, 2 in nursery 2436520962768
  Task 3 started, expecting lifetime of 5s!
  Task 2 finished, actual lifetime was 2.01447s!
Task 2436520969856 done, 1 in nursery 2436520962768
Task 2436520969856 added, 2 in nursery 2436520962768
  Task 4 started, expecting lifetime of 3s!
  Task 4 finished, actual lifetime was 3.01358s!
Task 2436520969856 done, 1 in nursery 2436520962768
  Task 3 finished, actual lifetime was 5.01383s!
Task 2436520969536 done, 0 in nursery 2436520962768
Task 2436520969056 done, -1 in nursery 2436520962768
Task 2436520969216 done, -1 in nursery 2436520962768
Task 2436520968896 done, -1 in nursery 2436520962768

Process finished with exit code 0

The results look promising, until we realize it is literally a global context.

Every other nursery in the run, including trio's own internal ones, triggers these callbacks as well, which produces nonsensical task counts such as -1.

The best solution I could think of was to check whether the task belongs to the nursery:

        def add_callback(task):
            # do something
            if task in nursery.child_tasks:  # <---- never true for a task that is still being added
                # child tasks count + 1 because given task is yet to be added.
                print(f"Task {id(task)} added, {len(nursery.child_tasks) + 1} in nursery {id(nursery)}")

        def remove_callback(task):
            # do something
            if task in nursery.child_tasks:
                # child tasks count - 1 because given task is yet to be removed from it.
                print(f"Task {id(task)} done, {len(nursery.child_tasks) - 1} in nursery {id(nursery)}")

But this obviously can't catch a task that is only just being added, since it isn't in nursery.child_tasks yet when add_callback runs.

Upvotes: 2

Views: 195

Answers (1)

Alex Grönholm

Reputation: 5911

Have you considered trio instrumentation as a solution?

from collections import defaultdict

from trio import open_nursery
from trio.abc import Instrument
from trio.lowlevel import add_instrument

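# Live task count per nursery; the key is task.parent_nursery.
task_counts = defaultdict(int)

class TaskCountInstrument(Instrument):
    # Called by trio whenever any task in the run is spawned.
    def task_spawned(self, task):
        task_counts[task.parent_nursery] += 1

    # Called by trio whenever any task in the run exits.
    def task_exited(self, task):
        task_counts[task.parent_nursery] -= 1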


async def main():
    add_instrument(TaskCountInstrument())
    async with open_nursery() as nursery:
        ...

Upvotes: 4
