Reputation: 11365
How would I reuse a socket to a server in asyncio, instead of creating a new connection for each query?
Here is my code:
async def lookup(server, port, query, sema):
    async with sema as sema:
        try:
            reader, writer = await asyncio.open_connection(server, port)
        except:
            return {}
        writer.write(query.encode("ISO-8859-1"))
        await writer.drain()
        data = b""
        while True:
            d = await reader.read(4096)
            if not d:
                break
            data += d
        writer.close()
        data = data.decode("ISO-8859-1")
        return data
Upvotes: 4
Views: 3608
Reputation: 1121924
You'd simply call the `asyncio.open_connection(server, port)` coroutine just once, and keep using the reader and writer (provided the server doesn't just close the connection on their end, of course).
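For example, a minimal sketch of that idea, assuming a made-up line-based service on lookup.example.com:7777 where every reply ends with a newline (some such framing is needed so you know where one reply stops without reading to EOF):

import asyncio


async def main():
    # open the connection once ...
    reader, writer = await asyncio.open_connection("lookup.example.com", 7777)
    try:
        # ... then send several queries over the same socket. This assumes a
        # hypothetical protocol where every reply ends with "\n"; reading to
        # EOF, as in the question, only works for the final query, because
        # EOF means the server has closed the connection.
        for query in ("first query\r\n", "second query\r\n"):
            writer.write(query.encode("ISO-8859-1"))
            await writer.drain()
            reply = await reader.readline()
            print(reply.decode("ISO-8859-1"))
    finally:
        writer.close()


asyncio.run(main())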
I'd do so in a separate async context manager object for your connections, and use a connection pool to manage the connections, so you can create and re-use socket connections for many concurrent tasks. By using an (async) context manager, Python makes sure to notify the connection when your code is done with it, so the connection can be released back to the pool:
import asyncio
import contextlib
from collections import OrderedDict
from types import TracebackType
from typing import Any, List, Optional, Tuple, Type

try:  # Python 3.7
    base = contextlib.AbstractAsyncContextManager
except AttributeError:
    base = object  # type: ignore

Server = str
Port = int
Host = Tuple[Server, Port]


class ConnectionPool(base):
    def __init__(
        self,
        max_connections: int = 1000,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self.max_connections = max_connections
        self._loop = loop or asyncio.get_event_loop()
        self._connections: OrderedDict[Host, List["Connection"]] = OrderedDict()
        self._semaphore = asyncio.Semaphore(max_connections)

    async def connect(self, server: Server, port: Port) -> "Connection":
        host = (server, port)
        # enforce the connection limit; releasing connections notifies
        # the semaphore to release here
        await self._semaphore.acquire()
        connections = self._connections.setdefault(host, [])
        # find an un-used connection for this host
        connection = next((conn for conn in connections if not conn.in_use), None)
        if connection is None:
            # disconnect the least-recently-used un-used connection (if any)
            # to make space for a new connection, then stop looking
            for conns_per_host in reversed(self._connections.values()):
                for conn in conns_per_host:
                    if not conn.in_use:
                        await conn.close()
                        break
                else:
                    continue
                break
            try:
                reader, writer = await asyncio.open_connection(server, port)
            except Exception:
                # the connection could not be made; give the slot back
                self._semaphore.release()
                raise
            connection = Connection(self, host, reader, writer)
            connections.append(connection)
        connection.in_use = True
        # move the current host to the front as most-recently used
        self._connections.move_to_end(host, False)
        return connection

    async def close(self):
        """Close all connections"""
        connections = [c for cs in self._connections.values() for c in cs]
        self._connections = OrderedDict()
        for connection in connections:
            await connection.close()

    def _remove(self, connection):
        conns_for_host = self._connections.get(connection._host)
        if not conns_for_host:
            return
        conns_for_host[:] = [c for c in conns_for_host if c != connection]

    def _notify_release(self):
        self._semaphore.release()

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    def __del__(self) -> None:
        connections = [repr(c) for cs in self._connections.values() for c in cs]
        if not connections:
            return
        context = {
            "pool": self,
            "connections": connections,
            "message": "Unclosed connection pool",
        }
        self._loop.call_exception_handler(context)


class Connection(base):
    def __init__(
        self,
        pool: ConnectionPool,
        host: Host,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        self._host = host
        self._pool = pool
        self._reader = reader
        self._writer = writer
        self._closed = False
        self.in_use = False

    def __repr__(self):
        host = f"{self._host[0]}:{self._host[1]}"
        return f"Connection<{host}>"

    @property
    def closed(self):
        return self._closed

    def release(self) -> None:
        self.in_use = False
        self._pool._notify_release()

    async def close(self) -> None:
        if self._closed:
            return
        self._closed = True
        self._writer.close()
        self._pool._remove(self)
        try:
            await self._writer.wait_closed()
        except AttributeError:  # wait_closed is new in 3.7
            pass

    def __getattr__(self, name: str) -> Any:
        """All unknown attributes are delegated to the reader and writer"""
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        if hasattr(self._reader, name):
            return getattr(self._reader, name)
        return getattr(self._writer, name)

    async def __aenter__(self) -> "Connection":
        if self._closed or not self.in_use:
            raise ValueError("Can't use a closed or unacquired connection")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        self.release()

    def __del__(self) -> None:
        if self._closed:
            return
        context = {"connection": self, "message": "Unclosed connection"}
        self._pool._loop.call_exception_handler(context)
Then pass in a pool object to your `lookup` coroutine; the connection object it produces proxies for both the reader and writer parts:
async def lookup(pool, server, port, query):
    try:
        conn = await pool.connect(server, port)
    except (ValueError, OSError):
        return {}
    async with conn:
        conn.write(query.encode("ISO-8859-1"))
        await conn.drain()
        data = b""
        while True:
            d = await conn.read(4096)
            if not d:
                break
            data += d
        data = data.decode("ISO-8859-1")
        return data
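Not part of the answer's code, but one possible way to wire the two together is to create the pool once and fan out concurrent lookups; the host, port and queries below are placeholders:

async def main():
    # arbitrary cap of 100 simultaneously checked-out connections
    async with ConnectionPool(max_connections=100) as pool:
        queries = [f"query {i}\r\n" for i in range(20)]
        results = await asyncio.gather(
            *(lookup(pool, "lookup.example.com", 4343, query) for query in queries)
        )
    print(results)


asyncio.run(main())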
Note that the standard WHOIS protocol (RFC 3912, or its predecessors) states that the connection is closed after every query. If you are connecting to a standard WHOIS service on port 43, there is no point in re-using sockets.
What happens in this case is that the reader will have reached EOF (`reader.at_eof()` is true), and any further attempts at reading will simply return nothing (`reader.read(...)` will always return an empty `b''` value). Writing to the writer is not going to be an error until the socket connection is terminated by the remote side after a time-out. You can write all you want to the connection, but the WHOIS server will just ignore the queries.
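If you do try to reuse a connection against such a server, you can at least detect the dead socket up front; a small sketch of such a guard (this helper is not part of the answer above):

async def query_on(conn, query):
    # if the server already sent EOF, the socket is spent; further writes
    # would be silently ignored, so force the caller to reconnect instead
    if conn.at_eof():
        raise ConnectionResetError("server closed the connection, reconnect")
    conn.write(query.encode("ISO-8859-1"))
    await conn.drain()
    return await conn.read(4096)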
Upvotes: 5
Reputation: 154916
You can create a connection cache by storing the reader/writer pairs in a global dictionary.
# at top-level
connections = {}
Then in `lookup`, replace the call to `open_connection` with code that checks the dict first:
if (server, port) in connections:
    reader, writer = connections[server, port]
else:
    reader, writer = await asyncio.open_connection(server, port)
    connections[server, port] = reader, writer
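Put together with the question's lookup, the cached version could look roughly like this (a sketch, not the answer's exact code). Because the question's read loop runs until EOF, the server will have closed the socket by the time a query finishes, so the dead pair is dropped from the cache again; with a protocol that frames its replies you would keep the pair instead. Also note that, unlike the pool in the first answer, a plain dict does nothing to stop two concurrent tasks from interleaving traffic on the same reader/writer pair.

import asyncio

# at top-level: maps (server, port) -> (reader, writer)
connections = {}


async def lookup(server, port, query, sema):
    async with sema:
        reader, writer = connections.get((server, port), (None, None))
        if reader is None or reader.at_eof():
            # nothing cached, or the server hung up since last time: reconnect
            reader, writer = await asyncio.open_connection(server, port)
            connections[server, port] = reader, writer
        writer.write(query.encode("ISO-8859-1"))
        await writer.drain()
        data = b""
        while True:
            d = await reader.read(4096)
            if not d:
                break
            data += d
        # reading hit EOF, so the server closed this connection; drop it
        # from the cache rather than handing out a dead pair next time
        del connections[server, port]
        writer.close()
        return data.decode("ISO-8859-1")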
Upvotes: 2