|
| 1 | +""" Reverse-proxy customized for jupyter notebooks |
| 2 | +
|
| 3 | +TODO: document |
| 4 | +""" |
| 5 | + |
| 6 | +import asyncio |
| 7 | +import logging |
| 8 | +import pprint |
| 9 | + |
| 10 | +import aiohttp |
| 11 | +from aiohttp import client, web |
| 12 | + |
| 13 | +# TODO: find actual name in registry |
| 14 | +SUPPORTED_IMAGE_NAME = "jupyter" |
| 15 | +SUPPORTED_IMAGE_TAG = "==0.1.0" |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +async def handler(req: web.Request, service_url: str, **_kwargs) -> web.StreamResponse: |
| 21 | + # Resolved url pointing to backend jupyter service |
| 22 | + tarfind_url = service_url + req.path_qs |
| 23 | + |
| 24 | + reqH = req.headers.copy() |
| 25 | + if reqH['connection'] == 'Upgrade' and reqH['upgrade'] == 'websocket' and req.method == 'GET': |
| 26 | + |
| 27 | + ws_server = web.WebSocketResponse() |
| 28 | + await ws_server.prepare(req) |
| 29 | + logger.info('##### WS_SERVER %s', pprint.pformat(ws_server)) |
| 30 | + |
| 31 | + client_session = aiohttp.ClientSession(cookies=req.cookies) |
| 32 | + async with client_session.ws_connect( |
| 33 | + tarfind_url, |
| 34 | + ) as ws_client: |
| 35 | + logger.info('##### WS_CLIENT %s', pprint.pformat(ws_client)) |
| 36 | + |
| 37 | + async def ws_forward(ws_from, ws_to): |
| 38 | + async for msg in ws_from: |
| 39 | + logger.info('>>> msg: %s', pprint.pformat(msg)) |
| 40 | + mt = msg.type |
| 41 | + md = msg.data |
| 42 | + if mt == aiohttp.WSMsgType.TEXT: |
| 43 | + await ws_to.send_str(md) |
| 44 | + elif mt == aiohttp.WSMsgType.BINARY: |
| 45 | + await ws_to.send_bytes(md) |
| 46 | + elif mt == aiohttp.WSMsgType.PING: |
| 47 | + await ws_to.ping() |
| 48 | + elif mt == aiohttp.WSMsgType.PONG: |
| 49 | + await ws_to.pong() |
| 50 | + elif ws_to.closed: |
| 51 | + await ws_to.close(code=ws_to.close_code, message=msg.extra) |
| 52 | + else: |
| 53 | + raise ValueError( |
| 54 | + 'unexpected message type: %s' % pprint.pformat(msg)) |
| 55 | + |
| 56 | + await asyncio.wait([ws_forward(ws_server, ws_client), ws_forward(ws_client, ws_server)], return_when=asyncio.FIRST_COMPLETED) |
| 57 | + |
| 58 | + return ws_server |
| 59 | + else: |
| 60 | + |
| 61 | + async with client.request( |
| 62 | + req.method, tarfind_url, |
| 63 | + headers=reqH, |
| 64 | + allow_redirects=False, |
| 65 | + data=await req.read() |
| 66 | + ) as res: |
| 67 | + headers = res.headers.copy() |
| 68 | + body = await res.read() |
| 69 | + return web.Response( |
| 70 | + headers=headers, |
| 71 | + status=res.status, |
| 72 | + body=body |
| 73 | + ) |
| 74 | + return ws_server |
| 75 | + |
| 76 | + |
| 77 | +if __name__ == "__main__": |
| 78 | + # dummies for manual testing |
| 79 | + BASE_URL = 'http://0.0.0.0:8888' |
| 80 | + MOUNT_POINT = '/x/fakeUuid' |
| 81 | + |
| 82 | + def adapter(req: web.Request): |
| 83 | + return handler(req, service_url=BASE_URL) |
| 84 | + |
| 85 | + app = web.Application() |
| 86 | + app.router.add_route('*', MOUNT_POINT + '/{proxyPath:.*}', adapter) |
| 87 | + web.run_app(app, port=3984) |
0 commit comments