Jonathan
Jonathan

Reputation: 11365

Asyncio how to reuse a socket

How would I reuse a socket connection to a server in asyncio, instead of creating a new connection for each query?

Here is my code:

async def lookup(server, port, query, sema):
    """Send *query* to ``server:port`` and return the full response as text.

    A new connection is opened for every call. *sema* is an
    ``asyncio.Semaphore`` (or any async context manager) that bounds how
    many lookups run concurrently.

    Returns the response decoded as ISO-8859-1, or an empty dict when the
    connection cannot be established.
    """
    async with sema:
        try:
            reader, writer = await asyncio.open_connection(server, port)
        except (ValueError, OSError):
            # Only connection failures are treated as "no result"; a bare
            # ``except`` here would also swallow task cancellation.
            return {}
        try:
            writer.write(query.encode("ISO-8859-1"))
            await writer.drain()
            chunks = []
            while True:
                chunk = await reader.read(4096)
                if not chunk:  # EOF: the server closed its side
                    break
                chunks.append(chunk)
        finally:
            # Close the socket even when writing or reading fails.
            writer.close()
        return b"".join(chunks).decode("ISO-8859-1")

Upvotes: 4

Views: 3608

Answers (2)

Martijn Pieters
Martijn Pieters

Reputation: 1121924

You'd simply call the asyncio.open_connection(server, port) coroutine just once, and keep using the reader and writer (provided the server doesn't just close the connection on its end, of course).

I'd do so in a separate async context manager object for your connections, and use a connection pool to manage the connections, so you can create and re-use socket connections for many concurrent tasks. By using an (async) context manager, Python makes sure to notify the connection when your code is done with it, so the connection can be released back to the pool:

import asyncio
import contextlib

from collections import OrderedDict
from types import TracebackType
from typing import Any, List, Optional, Tuple, Type


# contextlib.AbstractAsyncContextManager was added in Python 3.7; on older
# interpreters fall back to a plain ``object`` base class.
base = getattr(contextlib, "AbstractAsyncContextManager", object)

# Readable aliases for the annotations used by the pool and its connections.
Server = str
Port = int
Host = Tuple[Server, Port]


class ConnectionPool(base):
    """Pool of re-usable asyncio stream connections, grouped per host.

    At most ``max_connections`` connections can be handed out at the same
    time; further ``connect()`` calls wait until a connection is released.
    Idle connections are kept open for re-use and the least-recently-used
    idle one is closed when room is needed for a new connection.

    The pool is an async context manager; leaving the ``async with`` block
    closes every connection it still tracks.
    """

    def __init__(
        self,
        max_connections: int = 1000,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self.max_connections = max_connections
        self._loop = loop or asyncio.get_event_loop()

        # Hosts are ordered from most- to least-recently used.
        self._connections: OrderedDict[Host, List["Connection"]] = OrderedDict()
        self._semaphore = asyncio.Semaphore(max_connections)
        # Number of connections currently being opened (not yet tracked in
        # self._connections); counted when deciding whether to evict.
        self._opening = 0

    async def connect(self, server: Server, port: Port) -> "Connection":
        """Return an acquired connection to ``(server, port)``.

        An idle connection to the same host is re-used when available,
        otherwise a new one is opened. The caller must release the
        connection again (``async with connection:`` does so on exit).

        Raises whatever ``asyncio.open_connection()`` raises; in that case
        the slot taken from the connection limit is given back.
        """
        host = (server, port)

        # enforce the connection limit, releasing connections notifies
        # the semaphore to release here
        await self._semaphore.acquire()
        try:
            # find an un-used connection for this host
            connection = next(
                (c for c in self._connections.get(host, ()) if not c.in_use),
                None,
            )
            if connection is None:
                connection = await self._open(host)

            connection.in_use = True
            # move current host to the front as most-recently used
            self._connections.move_to_end(host, False)
        except BaseException:
            # Nothing was handed out (connect failed or we were cancelled),
            # so the permit must go back or the pool slowly deadlocks.
            self._semaphore.release()
            raise

        return connection

    async def _open(self, host: "Host") -> "Connection":
        """Open and register a new connection, evicting an idle one if full."""
        self._opening += 1
        try:
            tracked = sum(len(conns) for conns in self._connections.values())
            if tracked + self._opening > self.max_connections:
                # Pick the victim before awaiting so the registry is never
                # iterated across a suspension point.
                victim = self._least_recently_used_idle()
                if victim is not None:
                    # close() drops the victim from the pool before it waits;
                    # a transport error while tearing down an idle socket
                    # must not prevent the new connection.
                    with contextlib.suppress(OSError):
                        await victim.close()

            reader, writer = await asyncio.open_connection(*host)
        finally:
            self._opening -= 1

        connection = Connection(self, host, reader, writer)
        # Look the list up again: close() may have reset the registry while
        # this coroutine was suspended.
        self._connections.setdefault(host, []).append(connection)
        return connection

    def _least_recently_used_idle(self) -> Optional["Connection"]:
        """Return one idle connection of the least-recently-used host, if any."""
        for conns_per_host in reversed(self._connections.values()):
            for conn in conns_per_host:
                if not conn.in_use:
                    return conn
        return None

    async def close(self):
        """Close all connections"""
        connections = [c for cs in self._connections.values() for c in cs]
        self._connections = OrderedDict()
        for connection in connections:
            await connection.close()

    def _remove(self, connection):
        """Stop tracking *connection*; called by ``Connection.close()``."""
        conns_for_host = self._connections.get(connection._host)
        if not conns_for_host:
            return
        # Mutate in place so other references to this list stay valid.
        conns_for_host[:] = [c for c in conns_for_host if c != connection]

    def _notify_release(self):
        """Give one slot of the connection limit back."""
        self._semaphore.release()

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    def __del__(self) -> None:
        # __init__ may have failed before the registry existed.
        registry = getattr(self, "_connections", None)
        if not registry:
            return
        connections = [repr(c) for cs in registry.values() for c in cs]
        if not connections:
            return

        context = {
            "pool": self,
            "connections": connections,
            "message": "Unclosed connection pool",
        }
        self._loop.call_exception_handler(context)


class Connection(base):
    """A pooled connection that proxies an asyncio reader/writer pair.

    Instances are created by ``ConnectionPool.connect()``. Unknown
    attributes (``read``, ``write``, ``drain``, ...) are forwarded to the
    underlying ``StreamReader`` / ``StreamWriter`` while the connection is
    acquired. Used as an async context manager, the connection is released
    back to the pool (not closed) on exit.
    """

    def __init__(
        self,
        pool: "ConnectionPool",
        host: Host,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        self._host = host
        self._pool = pool
        self._reader = reader
        self._writer = writer
        self._closed = False
        self.in_use = False

    def __repr__(self):
        host = f"{self._host[0]}:{self._host[1]}"
        return f"Connection<{host}>"

    @property
    def closed(self):
        """True once ``close()`` has been called."""
        return self._closed

    def release(self) -> None:
        """Hand the connection back to the pool so it can be re-used.

        Releasing a connection that is not acquired is a no-op, so calling
        ``release()`` inside an ``async with`` block does not return two
        slots to the pool's connection limit.
        """
        if not self.in_use:
            return
        self.in_use = False
        self._pool._notify_release()

    async def close(self) -> None:
        """Close the underlying transport and drop it from the pool."""
        if self._closed:
            return
        self._closed = True
        self._writer.close()
        self._pool._remove(self)
        try:
            await self._writer.wait_closed()
        except AttributeError:  # wait_closed is new in 3.7
            pass

    def __getattr__(self, name: str) -> Any:
        """All unknown attributes are delegated to the reader and writer"""
        state = self.__dict__
        try:
            closed = state["_closed"]
            in_use = state["in_use"]
            reader = state["_reader"]
            writer = state["_writer"]
        except KeyError:
            # Not fully initialised: going through ``self.<attr>`` here
            # would re-enter __getattr__ and recurse without end.
            raise AttributeError(name) from None
        if closed or not in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        if hasattr(reader, name):
            return getattr(reader, name)
        return getattr(writer, name)

    async def __aenter__(self) -> "Connection":
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        self.release()

    def __del__(self) -> None:
        # A partially initialised instance has nothing to report.
        if getattr(self, "_closed", True):
            return
        context = {"connection": self, "message": "Unclosed connection"}
        self._pool._loop.call_exception_handler(context)

Then pass a pool object in to your lookup coroutine; the connection object it produces acts as a proxy for both the reader and the writer:

async def lookup(pool, server, port, query):
    """Send *query* over a pooled connection and return the decoded reply.

    A connection to ``(server, port)`` is requested from *pool*; when that
    fails with ``ValueError`` or ``OSError`` an empty dict is returned.
    Otherwise the query is written, the response is read until EOF and
    returned decoded as ISO-8859-1. The connection is released when the
    ``async with`` block exits.
    """
    try:
        connection = await pool.connect(server, port)
    except (ValueError, OSError):
        return {}

    async with connection:
        connection.write(query.encode("ISO-8859-1"))
        await connection.drain()

        # Gather the response piece by piece until the peer signals EOF.
        received = []
        piece = await connection.read(4096)
        while piece:
            received.append(piece)
            piece = await connection.read(4096)

        return b"".join(received).decode("ISO-8859-1")

Note that the standard WHOIS protocol (RFC 3912 and its predecessors) states that the connection is closed after every query. If you are connecting to a standard WHOIS service on port 43, there is no point in re-using sockets.

What happens in this case is that the reader will have reached EOF (reader.at_eof() is true), and any further attempts at reading will simply return nothing (reader.read(...) will always return an empty b'' value). Writing to the writer is not going to be an error until the socket connection is terminated by the remote side after a time-out. You can write all you want to the connection, but the WHOIS server will just ignore the queries.

Upvotes: 5

user4815162342
user4815162342

Reputation: 154916

You can create a connection cache by storing the reader/writer pairs in a global dictionary.

# at top-level: cache of open connections, keyed by (server, port) and
# holding the (reader, writer) pair returned by asyncio.open_connection()
connections = {}

Then in lookup, replace the call to open_connection with code that checks the dict first:

if (server, port) in connections:
    reader, writer = connections[server, port]
else:
    reader, writer = await asyncio.open_connection(server, port)
    connections[server, port] = reader, writer

Upvotes: 2

Related Questions