Skip to content

Commit 9627c7c

Browse files
committed
add subprotocol for token-authenticated websockets
follows kubernetes' example of smuggling the token in the subprotocol itself
1 parent 0adfb2a commit 9627c7c

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

jupyter_server/auth/identity.py

+13
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from http.cookies import Morsel
2020

2121
from tornado import escape, httputil, web
22+
from tornado.websocket import WebSocketHandler
2223
from traitlets import Bool, Dict, Type, Unicode, default
2324
from traitlets.config import LoggingConfigurable
2425

@@ -106,6 +107,9 @@ def _backward_compat_user(got_user: t.Any) -> User:
106107
raise ValueError(msg)
107108

108109

110+
_TOKEN_SUBPROTOCOL = "v1.token.websocket.jupyter.org"
111+
112+
109113
class IdentityProvider(LoggingConfigurable):
110114
"""
111115
Interface for providing identity management and authentication.
@@ -424,6 +428,15 @@ def get_token(self, handler: web.RequestHandler) -> str | None:
424428
m = self.auth_header_pat.match(handler.request.headers.get("Authorization", ""))
425429
if m:
426430
user_token = m.group(2)
431+
if not user_token and isinstance(handler, WebSocketHandler):
432+
subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
433+
if subprotocol_header:
434+
subprotocols = [s.strip() for s in subprotocol_header.split(",")]
435+
for subprotocol in subprotocols:
436+
if subprotocol.startswith(_TOKEN_SUBPROTOCOL + "."):
437+
user_token = subprotocol[len(_TOKEN_SUBPROTOCOL) + 1 :]
438+
break
439+
427440
return user_token
428441

429442
async def get_user_token(self, handler: web.RequestHandler) -> User | None:

tests/base/test_websocket.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from tornado.websocket import WebSocketClosedError, WebSocketHandler
1111

1212
from jupyter_server.auth import IdentityProvider, User
13-
from jupyter_server.auth.decorator import allow_unauthenticated
13+
from jupyter_server.auth.decorator import allow_unauthenticated, ws_authenticated
1414
from jupyter_server.base.handlers import JupyterHandler
1515
from jupyter_server.base.websocket import WebSocketMixin
1616
from jupyter_server.serverapp import ServerApp
@@ -75,6 +75,12 @@ class NoAuthRulesWebsocketHandler(MockJupyterHandler):
7575
pass
7676

7777

78+
class AuthenticatedWebsocketHandler(MockJupyterHandler):
79+
@ws_authenticated
80+
def get(self, *args, **kwargs) -> None:
81+
return super().get(*args, **kwargs)
82+
83+
7884
class PermissiveWebsocketHandler(MockJupyterHandler):
7985
@allow_unauthenticated
8086
def get(self, *args, **kwargs) -> None:
@@ -126,6 +132,30 @@ async def test_websocket_auth_required(jp_serverapp, jp_ws_fetch):
126132
assert exception.value.code == 403
127133

128134

135+
async def test_websocket_token_subprotocol_auth(jp_serverapp, jp_ws_fetch):
136+
app: ServerApp = jp_serverapp
137+
app.web_app.add_handlers(
138+
".*$",
139+
[
140+
(url_path_join(app.base_url, "ws"), AuthenticatedWebsocketHandler),
141+
],
142+
)
143+
144+
with pytest.raises(HTTPClientError) as exception:
145+
ws = await jp_ws_fetch("ws", headers={"Authorization": ""})
146+
assert exception.value.code == 403
147+
token = jp_serverapp.identity_provider.token
148+
ws = await jp_ws_fetch(
149+
"ws",
150+
headers={
151+
"Authorization": "",
152+
"Sec-WebSocket-Protocol": "v1.kernel.websocket.jupyter.org, v1.token.websocket.jupyter.org."
153+
+ token,
154+
},
155+
)
156+
ws.close()
157+
158+
129159
class IndiscriminateIdentityProvider(IdentityProvider):
130160
async def get_user(self, handler):
131161
return User(username="test")

0 commit comments

Comments
 (0)