Recursing

Reputation: 648

Async Context Manager that performs an action on context switch

I would like to make a context manager for async functions that calls a function every time execution "moves" to another context.

For example:

import os
import asyncio

class AsyncContextChangeDir:
    def __init__(self, newdir):
        self.newdir = newdir

    async def __aenter__(self):
        self.curdir = os.getcwd()  # remember the directory at entry time
        os.chdir(self.newdir)
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        os.chdir(self.curdir)  # restore it when the block exits

async def workon_mypath():
    async with AsyncContextChangeDir("/tmp"):
        print("working in /tmp context manager, cwd:" + os.getcwd())  # /tmp
        await asyncio.sleep(100)
        print("working in /tmp context manager, cwd:" + os.getcwd()) # ???

async def workon_someotherpath():
    await asyncio.sleep(10)
    os.chdir("/home")
    print("working in other context cwd:" + os.getcwd())


async def main():
    await asyncio.gather(
        workon_mypath(),
        workon_someotherpath())

asyncio.run(main())

I would like the second print to print /tmp and, of course, the previous working directory to be restored every time execution "switches" to another context.

What's the best way to do this?
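For reference, a shortened reproduction of what the "# ???" line prints. It is a sketch under assumptions: the sleeps are replaced by asyncio.sleep(0) to make the interleaving deterministic, and temporary directories stand in for /tmp and /home. The working directory is process-wide state, so the second task's os.chdir is visible to the first task when it resumes.

```python
import asyncio
import os
import tempfile


async def worker_a(path, seen):
    os.chdir(path)             # task A switches to its own directory
    seen.append(os.getcwd())
    await asyncio.sleep(0)     # suspend: task B runs before A resumes
    seen.append(os.getcwd())   # cwd is now whatever task B set


async def worker_b(path):
    os.chdir(path)             # changes the cwd of the whole process


async def main():
    original = os.getcwd()
    seen = []
    with tempfile.TemporaryDirectory() as a, tempfile.TemporaryDirectory() as b:
        try:
            await asyncio.gather(worker_a(a, seen), worker_b(b))
        finally:
            os.chdir(original)  # leave the temp dirs before they are deleted
        return os.path.realpath(a), os.path.realpath(b), seen


a, b, seen = asyncio.run(main())
print(seen == [a, b])  # True: task A's second read sees task B's directory
```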

Upvotes: 2

Views: 857

Answers (2)

Hzz

Reputation: 1928

I think I found a way to detect a context switch, but, as anyone would tell you, it's a dirty hack and strongly discouraged in production. Another problem is that I don't yet know how to get the specific function a Task is currently executing: the hook only sees the Task's top-level coroutine, which is why the output below reports multiple_calls() rather than example_coroutine(). The approach involves overriding the private BaseEventLoop._run_once method, as in this example.
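A less invasive sketch of the same idea, which does not copy _run_once: wrap a single coroutine in an object that implements the coroutine protocol (__await__, send, throw, close) and call hooks around every step. Each send/throw is one resumption of the wrapped coroutine, and returning from it is a suspension. The class name StepHook and its on_resume/on_suspend callbacks are hypothetical; the wrapper only sees steps of the coroutine it wraps, not separate tasks that coroutine starts with create_task.

```python
import asyncio


class StepHook:
    """Hypothetical wrapper: calls on_resume() before and on_suspend() after
    every step (send/throw) of the wrapped coroutine."""

    def __init__(self, coro, on_resume, on_suspend):
        self._coro = coro
        self._on_resume = on_resume
        self._on_suspend = on_suspend

    def _step(self, method, arg):
        self._on_resume()
        try:
            # Returns what the coroutine yields to the event loop, or raises
            # StopIteration(result) when the coroutine finishes.
            return method(arg)
        finally:
            self._on_suspend()

    # Protocol used by `await` (yield from) to drive this object.
    def __await__(self):
        return self

    def __iter__(self):
        return self

    def __next__(self):
        return self._step(self._coro.send, None)

    def send(self, value):
        return self._step(self._coro.send, value)

    def throw(self, typ, val=None, tb=None):
        # Forward one exception object (the 3-argument form is deprecated).
        exc = typ if val is None else val
        if tb is not None and isinstance(exc, BaseException):
            exc = exc.with_traceback(tb)
        return self._step(self._coro.throw, exc)

    def close(self):
        return self._coro.close()


async def worker(events):
    events.append("start")
    await asyncio.sleep(0)  # suspension point 1
    events.append("middle")
    await asyncio.sleep(0)  # suspension point 2
    events.append("end")
    return 42


async def main():
    events = []
    result = await StepHook(
        worker(events),
        on_resume=lambda: events.append("resume"),
        on_suspend=lambda: events.append("suspend"),
    )
    return result, events


result, events = asyncio.run(main())
print(result)  # 42
print(events)
```

Using "chdir into the target directory" and "chdir back to the saved directory" as the two hooks would give the per-coroutine working directory the question asks for, since no other task on the same event loop runs between on_resume and on_suspend.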

Output:

================= Context switching happen for main() in asyncio_testing_2.py, line 9 ================= 
start main
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 16 ================= 
start multiple_calls 1
start example_coroutine 1
start example_coroutine_extended 1
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 16 ================= 
start multiple_calls 2
start example_coroutine 2
start example_coroutine_extended 2
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 18 ================= 
end example_coroutine_extended 1
start example_coroutine_extended 1
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 18 ================= 
end example_coroutine_extended 2
start example_coroutine_extended 2
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 18 ================= 
end example_coroutine_extended 1
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 18 ================= 
end example_coroutine_extended 2
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 18 ================= 
end example_coroutine 1
start example_coroutine 1
start example_coroutine_extended 1
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 18 ================= 
end example_coroutine 2
start example_coroutine 2
start example_coroutine_extended 2
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 19 ================= 
end example_coroutine_extended 1
start example_coroutine_extended 1
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 19 ================= 
end example_coroutine_extended 2
start example_coroutine_extended 2
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 19 ================= 
end example_coroutine_extended 1
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 19 ================= 
end example_coroutine_extended 2
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 19 ================= 
end example_coroutine 1
start blocking_coroutine 1
end blocking_coroutine 1
end multiple_calls 1
================= Context switching happen for multiple_calls() in asyncio_testing_2.py, line 19 ================= 
end example_coroutine 2
start blocking_coroutine 2
end blocking_coroutine 2
end multiple_calls 2
================= Context switching happen for main() in asyncio_testing_2.py, line 13 ================= 
end main

Process finished with exit code 0

Code:

import asyncio
import heapq
import time
# Private CPython internals: these names may change between Python versions.
# noinspection PyUnresolvedReferences
from asyncio.base_events import MAXIMUM_SELECT_TIMEOUT, _MIN_SCHEDULED_TIMER_HANDLES, _MIN_CANCELLED_TIMER_HANDLES_FRACTION
from types import MethodType


async def main():
    print("start main")
    task1 = asyncio.create_task(multiple_calls(1))
    task2 = asyncio.create_task(multiple_calls(2))
    await asyncio.gather(task1, task2)
    print("end main")

async def multiple_calls(_id):
    print(f"start multiple_calls {_id}")
    await example_coroutine(_id)
    await example_coroutine(_id)
    await blocking_coroutine(_id)
    print(f"end multiple_calls {_id}")


async def example_coroutine(_id):
    print(f"start example_coroutine {_id}")
    await example_coroutine_extended(_id)
    await example_coroutine_extended(_id)
    await asyncio.sleep(1)
    print(f"end example_coroutine {_id}")


async def example_coroutine_extended(_id):
    print(f"start example_coroutine_extended {_id}")
    await asyncio.sleep(1)
    print(f"end example_coroutine_extended {_id}")


async def blocking_coroutine(_id):
    print(f"start blocking_coroutine {_id}")
    time.sleep(1)
    print(f"end blocking_coroutine {_id}")


def _run_once(self):
    """Run one full iteration of the event loop.

    This calls all currently ready callbacks, polls for I/O,
    schedules the resulting callbacks, and finally schedules
    'call_later' callbacks.
    """

    sched_count = len(self._scheduled)
    if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and
            self._timer_cancelled_count / sched_count >
            _MIN_CANCELLED_TIMER_HANDLES_FRACTION):
        # Remove delayed calls that were cancelled if their number
        # is too high
        new_scheduled = []
        for handle in self._scheduled:
            if handle._cancelled:
                handle._scheduled = False
            else:
                new_scheduled.append(handle)

        heapq.heapify(new_scheduled)
        self._scheduled = new_scheduled
        self._timer_cancelled_count = 0
    else:
        # Remove delayed calls that were cancelled from head of queue.
        while self._scheduled and self._scheduled[0]._cancelled:
            self._timer_cancelled_count -= 1
            handle = heapq.heappop(self._scheduled)
            handle._scheduled = False

    timeout = None
    if self._ready or self._stopping:
        timeout = 0
    elif self._scheduled:
        # Compute the desired timeout.
        timeout = self._scheduled[0]._when - self.time()
        if timeout > MAXIMUM_SELECT_TIMEOUT:
            timeout = MAXIMUM_SELECT_TIMEOUT
        elif timeout < 0:
            timeout = 0

    event_list = self._selector.select(timeout)
    self._process_events(event_list)
    # Needed to break cycles when an exception occurs.
    event_list = None

    # Handle 'later' callbacks that are ready.
    end_time = self.time() + self._clock_resolution
    while self._scheduled:
        handle = self._scheduled[0]
        if handle._when >= end_time:
            break
        handle = heapq.heappop(self._scheduled)
        handle._scheduled = False
        self._ready.append(handle)

    # This is the only place where callbacks are actually *called*.
    # All other places just add them to ready.
    # Note: We run all currently scheduled callbacks, but not any
    # callbacks scheduled by callbacks run this time around --
    # they will be run the next time (after another I/O poll).
    # Use an idiom that is thread-safe without using locks.
    ntodo = len(self._ready)
    for i in range(ntodo):
        handle = self._ready.popleft()
        if handle._cancelled:
            continue

        #################### MODIFIED ####################
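        # Task steps/wakeups are scheduled as callbacks bound to the Task,
        # so their __self__ is the Task. Plain functions, functools.partial
        # objects, etc. may lack __self__, so default to None.
        task = getattr(handle._callback, "__self__", None)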
        if isinstance(task, asyncio.Task):
            on_context_switch_event_handler(handle, task)
        handle._run()
        ##################################################
    handle = None  # Needed to break cycles when an exception occurs.



def on_context_switch_event_handler(handle, task):
    task_code_info = task.get_coro().cr_code
    task_frame_info = task.get_coro().cr_frame

    func_name = get_coro_name(task.get_coro())
    line_pos = task_frame_info.f_lineno
    file_name = task_code_info.co_filename.replace("\\", "/").split("/")[-1]
    print(f"================= Context switching happen for {func_name} in {file_name}, line {line_pos} ================= ")

def get_coro_name(coro):
    # Coroutines compiled with Cython sometimes don't have
    # proper __qualname__ or __name__.  While that is a bug
    # in Cython, asyncio shouldn't crash with an AttributeError
    # in its __repr__ functions.
    if hasattr(coro, '__qualname__') and coro.__qualname__:
        coro_name = coro.__qualname__
    elif hasattr(coro, '__name__') and coro.__name__:
        coro_name = coro.__name__
    else:
        # Stop masking Cython bugs, expose them in a friendly way.
        coro_name = f'<{type(coro).__name__} without __name__>'
    return f'{coro_name}()'

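# Create the loop explicitly; calling asyncio.get_event_loop() with no running
# loop is deprecated in recent Python versions.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._run_once = MethodType(_run_once, loop)  # patch this loop instance only
try:
    loop.run_until_complete(main())
finally:
    loop.close()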


Upvotes: 0

user2357112

Reputation: 280101

Contrary to what you might expect from the name, context managers as a concept don't have anything to do with context switches.
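A small sketch illustrating this: the hooks of an async context manager run once when the async with block is entered and once when it is left, however many times the body suspends and other tasks run in between. The class name Recorder is hypothetical.

```python
import asyncio


class Recorder:
    """Async context manager that logs when its hooks run."""

    def __init__(self, log):
        self.log = log

    async def __aenter__(self):
        self.log.append("aenter")
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        self.log.append("aexit")
        return False  # do not suppress exceptions


async def body(log):
    async with Recorder(log):
        for i in range(3):
            log.append(f"step {i}")
            await asyncio.sleep(0)  # yields to the event loop each time


async def other(log):
    for i in range(3):
        log.append(f"other {i}")
        await asyncio.sleep(0)


async def main():
    log = []
    await asyncio.gather(body(log), other(log))
    return log


log = asyncio.run(main())
print(log)
```

The "other" entries interleave with the body's steps, but nothing is logged by Recorder at those switches.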

Neither regular context managers nor asynchronous context managers are informed about event loop "context switches". There is no way for a context manager to detect that the event loop is going to start running another coroutine, and there is no way for a context manager to execute code when that happens.
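If the underlying goal is per-task directory state, one alternative (not part of this answer, and a different technique from os.chdir) is to keep a logical working directory in a contextvars.ContextVar and resolve relative paths against it explicitly. Each asyncio task runs in its own copy of the context, so a value set in one task is not seen by another. The names logical_cwd, workdir and resolve are hypothetical, and the real process cwd is never changed, so code that opens bare relative paths would still need to go through resolve().

```python
import asyncio
import contextlib
import contextvars
import os

# Per-task "working directory"; defaults to the real cwd at import time.
logical_cwd = contextvars.ContextVar("logical_cwd", default=os.getcwd())


@contextlib.contextmanager
def workdir(path):
    """Set the logical cwd for the current task until the block exits."""
    token = logical_cwd.set(os.path.join(logical_cwd.get(), path))
    try:
        yield
    finally:
        logical_cwd.reset(token)


def resolve(relative):
    """Turn a relative path into an absolute one using the logical cwd."""
    return os.path.join(logical_cwd.get(), relative)


async def task_a(results):
    with workdir("/tmp"):
        results.append(("a", resolve("x.txt")))
        await asyncio.sleep(0)                   # task_b runs here
        results.append(("a", resolve("x.txt")))  # still under /tmp


async def task_b(results):
    with workdir("/home"):
        results.append(("b", resolve("x.txt")))
        await asyncio.sleep(0)


async def main():
    results = []
    await asyncio.gather(task_a(results), task_b(results))
    return results


results = asyncio.run(main())
print(results)  # on POSIX: [('a', '/tmp/x.txt'), ('b', '/home/x.txt'), ('a', '/tmp/x.txt')]
```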

Upvotes: 0
