Tobi Oetiker
Tobi Oetiker

Reputation: 5470

How to implement a WebSocket-aware reverse proxy with aiohttp (Python 3.6)

I am trying to implement an application-specific reverse proxy for Jupyter notebooks using aiohttp. It works fine for HTTP requests, but WebSocket forwarding does not work. Requests from the browser arrive and get forwarded, but no responses from Jupyter are forthcoming. I assume my WebSocket client code somehow does not react to incoming messages from Jupyter.

The only indication on the Jupyter side that something is amiss is a message like this:

WebSocket ping timeout after 90009 ms.

So here is my attempt at writing the proxy:

from aiohttp import web
from aiohttp import client
import aiohttp
import logging
import pprint

# Log at INFO level so the proxy's logger.info traces are printed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Upstream (Jupyter) server that requests are forwarded to.
# NOTE(review): 0.0.0.0 is a listen address; connecting to it works on Linux,
# but 127.0.0.1 is the portable choice -- confirm for your deployment.
baseUrl = 'http://0.0.0.0:8888'
# URL prefix under which the proxy route is registered.
mountPoint = '/fakeUuid'

async def handler(req):
    """Proxy *req* to the server at ``baseUrl`` (the question's first attempt).

    WebSocket upgrade requests are bridged to a client WebSocket opened on
    ``baseUrl + req.path_qs``; all other requests are replayed with
    ``client.request`` and the upstream response is copied back.

    NOTE(review): as the question reports, the WebSocket branch never relays
    upstream replies -- see the comment on the first ``async for`` below.
    """
    proxyPath = req.match_info.get('proxyPath','no proxyPath placeholder defined')
    reqH = req.headers.copy()
    # NOTE(review): indexing raises KeyError when the Connection or Upgrade
    # header is absent, and the exact comparison misses Connection values
    # such as 'keep-alive, Upgrade' -- confirm against the browsers in use.
    if reqH['connection'] == 'Upgrade' and reqH['upgrade'] == 'websocket' and req.method == 'GET':

      ws_server = web.WebSocketResponse()
      await ws_server.prepare(req)
      logger.info('##### WS_SERVER %s' % pprint.pformat(ws_server))

      # NOTE(review): this session is never closed in this block.
      client_session = aiohttp.ClientSession()
      # Only the browser's Cookie header is forwarded (KeyError if none was sent).
      async with client_session.ws_connect(baseUrl+req.path_qs,
        headers = { 'cookie': reqH['cookie'] },
      ) as ws_client:
        logger.info('##### WS_CLIENT %s' % pprint.pformat(ws_client))

        # Browser -> upstream. This loop keeps running until the browser
        # socket stops yielding messages, so the upstream -> browser loop
        # below cannot start before then. This sequencing is why no
        # replies ever reach the browser.
        async for server_msg in ws_server:
          logger.info('>>> msg from browser: %s',pprint.pformat(server_msg))
          if server_msg.type == aiohttp.WSMsgType.TEXT:
            await ws_client.send_str(server_msg.data)
          else:
            # NOTE(review): every non-TEXT message (not only BINARY) is sent as bytes.
            await ws_client.send_bytes(server_msg.data)

        # Upstream -> browser; only reached after the loop above has ended.
        async for client_msg in ws_client:
          logger.info('>>> msg from jupyter: %s',pprint.pformat(client_msg))
          # NOTE(review): `.tp` is an older name for `.type` on aiohttp
          # messages; verify it exists in the installed aiohttp version.
          if client_msg.tp == aiohttp.WSMsgType.TEXT:
            await ws_server.send_str(client_msg.data)
          else:
            await ws_server.send_bytes(client_msg.data)

        return ws_server
    else:
      # Plain HTTP: replay the request upstream and copy the response back.
      # NOTE(review): the URL is built from the matched path only, so any
      # query string on the incoming request is not forwarded.
      async with client.request(
          req.method,baseUrl+mountPoint+proxyPath,
          headers = reqH,
          allow_redirects=False,
          data = await req.read()
      ) as res:
          headers = res.headers.copy()
          body = await res.read()
          return web.Response(
            headers = headers,
            status = res.status,
            body = body
          )
      # NOTE(review): unreachable (the block above always returns), and
      # ws_server is not defined on this branch.
      return ws_server

# Route every HTTP method and every path below mountPoint to handler; the
# remainder of the path is captured as the 'proxyPath' match_info entry.
app = web.Application()
app.router.add_route('*',mountPoint + '{proxyPath:.*}', handler)
# Start serving on port 3984.
# NOTE(review): this runs at import time; an `if __name__ == '__main__':`
# guard would make the module importable without starting the server.
web.run_app(app,port=3984)

Upvotes: 2

Views: 1697

Answers (1)

Tobi Oetiker
Tobi Oetiker

Reputation: 5470

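Lesson learned: the two `async for` loops run one after the other within the current function. The first one blocks until the browser's socket closes, so the second one never gets to forward Jupyter's replies. By running both loops concurrently with `asyncio.wait`, messages flow in both directions at the same time. The resulting program looks like this: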

from aiohttp import web
from aiohttp import client
import aiohttp
import asyncio
import logging
import pprint

# Log at INFO level so the proxy's logger.info traces are printed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Upstream (Jupyter) server that requests are forwarded to.
# NOTE(review): 0.0.0.0 is a listen address; connecting to it works on Linux,
# but 127.0.0.1 is the portable choice -- confirm for your deployment.
baseUrl = 'http://0.0.0.0:8888'
# URL prefix under which the proxy route is registered.
mountPoint = '/fakeUuid'

# Hop-by-hop headers (RFC 7230, section 6.1) describe a single connection and
# must not be forwarded by a proxy.
_HOP_BY_HOP_HEADERS = (
    'Connection',
    'Keep-Alive',
    'Proxy-Authenticate',
    'Proxy-Authorization',
    'TE',
    'Trailer',
    'Transfer-Encoding',
    'Upgrade',
)

# Close codes that RFC 6455 (section 7.4.1) forbids sending in a Close frame.
_UNSENDABLE_CLOSE_CODES = frozenset({1005, 1006, 1015})


def _is_websocket_upgrade(req):
    """Return True if *req* is a GET asking to switch to the WebSocket protocol.

    ``Connection`` is parsed as a comma-separated, case-insensitive token list,
    so ``Connection: keep-alive, Upgrade`` is recognised. Missing headers mean
    "not an upgrade" rather than raising KeyError.
    """
    connection_tokens = {
        token.strip().lower()
        for token in req.headers.get('Connection', '').split(',')
    }
    return (
        req.method == 'GET'
        and 'upgrade' in connection_tokens
        and req.headers.get('Upgrade', '').lower() == 'websocket'
    )


def _strip_headers(headers, extra=()):
    """Return a copy of *headers* without hop-by-hop headers.

    Also drops every header named in the ``Connection`` header and every
    name in *extra*. ``del`` removes all occurrences of a name, while other
    repeated headers (e.g. ``Set-Cookie``) are kept as-is in the copy.
    """
    stripped = headers.copy()
    names = list(_HOP_BY_HOP_HEADERS) + list(extra)
    names += [
        token.strip()
        for token in headers.get('Connection', '').split(',')
        if token.strip()
    ]
    for name in names:
        try:
            del stripped[name]
        except KeyError:
            pass  # header not present: nothing to strip
    return stripped


async def _ws_forward(ws_from, ws_to):
    """Copy messages from *ws_from* to *ws_to* until *ws_from* closes or fails.

    Afterwards *ws_to* is closed with *ws_from*'s close code. Code 1000 is
    used instead when that code is unknown or may not be sent, so the other
    peer learns the conversation is over.

    Raises:
        ValueError: on a message type this proxy does not know how to relay.
    """
    async for msg in ws_from:
        logger.info('>>> msg: %s', pprint.pformat(msg))
        if msg.type == aiohttp.WSMsgType.TEXT:
            await ws_to.send_str(msg.data)
        elif msg.type == aiohttp.WSMsgType.BINARY:
            await ws_to.send_bytes(msg.data)
        elif msg.type == aiohttp.WSMsgType.PING:
            await ws_to.ping()
        elif msg.type == aiohttp.WSMsgType.PONG:
            await ws_to.pong()
        elif msg.type == aiohttp.WSMsgType.ERROR:
            # For ERROR messages aiohttp puts the exception in msg.data;
            # treat it like a closed connection.
            logger.warning('websocket error: %s', msg.data)
            break
        else:
            raise ValueError('unexpected message type: %s' % pprint.pformat(msg))
    if not ws_to.closed:
        code = ws_from.close_code
        if code is None or code in _UNSENDABLE_CLOSE_CODES:
            code = aiohttp.WSCloseCode.OK
        await ws_to.close(code=code)


async def _proxy_websocket(req):
    """Bridge the browser WebSocket in *req* to the same path on ``baseUrl``."""
    ws_server = web.WebSocketResponse()
    await ws_server.prepare(req)
    logger.info('##### WS_SERVER %s', pprint.pformat(ws_server))

    # ``async with`` guarantees the client session and socket are closed,
    # even when forwarding fails.
    async with aiohttp.ClientSession(cookies=req.cookies) as client_session:
        async with client_session.ws_connect(baseUrl + req.path_qs) as ws_client:
            logger.info('##### WS_CLIENT %s', pprint.pformat(ws_client))
            # Both directions must run concurrently: each ``async for`` only
            # ends when its socket closes. Wrapping the coroutines in tasks
            # keeps asyncio.wait valid on current Python (bare coroutines
            # were removed in 3.11) while staying compatible with 3.6.
            tasks = [
                asyncio.ensure_future(_ws_forward(ws_server, ws_client)),
                asyncio.ensure_future(_ws_forward(ws_client, ws_server)),
            ]
            try:
                done, _ = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # Stop the remaining forwarder (cancel is a no-op on finished
                # tasks) and let it unwind so nothing outlives the handler.
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)
            # Surface errors from a finished forwarder instead of dropping them.
            for task in done:
                task.result()
    return ws_server


async def _proxy_http(req):
    """Forward a plain HTTP request to ``baseUrl`` and relay the response.

    The full path *and* query string of *req* are forwarded.
    """
    async with client.request(
        req.method,
        baseUrl + req.path_qs,
        headers=_strip_headers(req.headers),
        allow_redirects=False,
        data=await req.read(),
    ) as res:
        body = await res.read()
        # aiohttp's client decompresses bodies by default, so the upstream
        # Content-Encoding/Content-Length no longer describe ``body``;
        # web.Response computes a fresh Content-Length itself.
        headers = _strip_headers(
            res.headers, extra=('Content-Length', 'Content-Encoding'))
        return web.Response(headers=headers, status=res.status, body=body)


async def handler(req):
    """Reverse-proxy *req* to the server at ``baseUrl``.

    WebSocket upgrade requests are bridged message-by-message in both
    directions; every other request is forwarded as plain HTTP.

    Returns:
        The ``web.WebSocketResponse`` or ``web.Response`` sent to the browser.
    """
    if _is_websocket_upgrade(req):
        return await _proxy_websocket(req)
    return await _proxy_http(req)

# Route every HTTP method and every path below mountPoint to handler; the
# remainder of the path is captured as the 'proxyPath' match_info entry.
app = web.Application()
app.router.add_route('*',mountPoint + '{proxyPath:.*}', handler)
# Start serving on port 3984.
# NOTE(review): this runs at import time; an `if __name__ == '__main__':`
# guard would make the module importable without starting the server.
web.run_app(app,port=3984)

Upvotes: 2

Related Questions