Reputation: 3530
Is there an official or better approach to add a callback on a nursery's task add/remove, rather than wrapping the trio._core._run.GLOBAL_RUN_CONTEXT.runner.tasks set?
To study the internals of trio.Nursery (and just for fun), I'm trying to get the task count on every task add/remove of one specific nursery.
What I found is that trio uses a single set, stored as the tasks attribute of trio._core._run.GLOBAL_RUN_CONTEXT.runner.
So I tried wrapping that set, because replacing it outright caused a handler error on exit.
Demonstration code:
from collections.abc import MutableSet
from typing import Iterator
import random
import time
import trio


class SetWithCallback(MutableSet):
    """
    Class to wrap around existing set for adding callback support
    """

    def __init__(self, original_set: set, add_callback=None, remove_callback=None):
        self.inner_set = original_set
        self.add_callback = add_callback if add_callback else lambda _: None
        self.remove_callback = remove_callback if remove_callback else lambda _: None

    def add(self, value) -> None:
        self.inner_set.add(value)
        self.add_callback(value)

    def remove(self, value) -> None:
        self.inner_set.remove(value)
        self.remove_callback(value)

    def discard(self, value) -> None:
        self.inner_set.discard(value)

    def __contains__(self, x: object) -> bool:
        return x in self.inner_set

    def __len__(self) -> int:
        return len(self.inner_set)

    def __iter__(self) -> Iterator:
        return iter(self.inner_set)


async def dummy_task(task_no, lifetime):
    """
    A very meaningful and serious workload

    Args:
        task_no: Task's ID
        lifetime: Lifetime of task
    """
    print(f" Task {task_no} started, expecting lifetime of {lifetime}s!")
    start = time.time()
    await trio.sleep(lifetime)
    print(f" Task {task_no} finished, actual lifetime was {time.time() - start:.6}s!")


async def main():
    # Wrap original tasks set with our new class
    # noinspection PyProtectedMember
    runner = trio._core._run.GLOBAL_RUN_CONTEXT.runner
    # noinspection PyTypeChecker
    runner.tasks = SetWithCallback(runner.tasks)

    async with trio.open_nursery() as nursery:
        # callback to be called on every task add/remove.
        # checks if task belongs to nursery.
        def add_callback(task):
            # do something
            # child tasks count + 1 because given task is yet to be added.
            print(f"Task {id(task)} added, {len(nursery.child_tasks) + 1} in nursery {id(nursery)}")

        def remove_callback(task):
            # do something
            # child tasks count - 1 because given task is yet to be removed from it.
            print(f"Task {id(task)} done, {len(nursery.child_tasks) - 1} in nursery {id(nursery)}")

        # replace default callback to count the task count in nursery.
        runner.tasks.add_callback = add_callback
        runner.tasks.remove_callback = remove_callback

        # spawn tasks
        for n in range(5):
            nursery.start_soon(dummy_task, n, random.randint(2, 5))
            await trio.sleep(1)


if __name__ == '__main__':
    trio.run(main)
Task 2436520969376 added, 1 in nursery 2436520962768
Task 0 started, expecting lifetime of 3s!
Task 2436520969536 added, 2 in nursery 2436520962768
Task 1 started, expecting lifetime of 2s!
Task 2436520969856 added, 3 in nursery 2436520962768
Task 2 started, expecting lifetime of 2s!
Task 0 finished, actual lifetime was 3.00393s!
Task 2436520969376 done, 2 in nursery 2436520962768
Task 1 finished, actual lifetime was 2.01102s!
Task 2436520969536 done, 1 in nursery 2436520962768
Task 2436520969536 added, 2 in nursery 2436520962768
Task 3 started, expecting lifetime of 5s!
Task 2 finished, actual lifetime was 2.01447s!
Task 2436520969856 done, 1 in nursery 2436520962768
Task 2436520969856 added, 2 in nursery 2436520962768
Task 4 started, expecting lifetime of 3s!
Task 4 finished, actual lifetime was 3.01358s!
Task 2436520969856 done, 1 in nursery 2436520962768
Task 3 finished, actual lifetime was 5.01383s!
Task 2436520969536 done, 0 in nursery 2436520962768
Task 2436520969056 done, -1 in nursery 2436520962768
Task 2436520969216 done, -1 in nursery 2436520962768
Task 2436520968896 done, -1 in nursery 2436520962768
Process finished with exit code 0
The results look promising, until you realize the set is literally global context.
Tasks from any other nursery, including trio's own internal ones, also go through this set, so they trigger the callbacks too and produce nonsense task counts like -1.
The best solution I could think of was checking whether the task belongs to the nursery:
def add_callback(task):
    # do something
    if task in nursery.child_tasks:  # <---- what??
        # child tasks count + 1 because given task is yet to be added.
        print(f"Task {id(task)} added, {len(nursery.child_tasks) + 1} in nursery {id(nursery)}")

def remove_callback(task):
    # do something
    if task in nursery.child_tasks:
        # child tasks count - 1 because given task is yet to be removed from it.
        print(f"Task {id(task)} done, {len(nursery.child_tasks) - 1} in nursery {id(nursery)}")
But this obviously can't work for a task that is only just being added: at that point it isn't in nursery.child_tasks yet, so the check fails.
Upvotes: 2
Views: 195
Reputation: 5911
Have you considered trio instrumentation as a solution?
from collections import defaultdict

from trio import open_nursery
from trio.abc import Instrument
from trio.lowlevel import add_instrument

task_counts = defaultdict(lambda: 0)


class TaskCountInstrument(Instrument):
    def task_spawned(self, task):
        task_counts[task.parent_nursery] += 1

    def task_exited(self, task):
        task_counts[task.parent_nursery] -= 1


async def main():
    add_instrument(TaskCountInstrument())
    async with open_nursery() as nursery:
        ...
Upvotes: 4