|
10 | 10 | from tornado.websocket import WebSocketClosedError, WebSocketHandler
|
11 | 11 |
|
12 | 12 | from jupyter_server.auth import IdentityProvider, User
|
13 |
| -from jupyter_server.auth.decorator import allow_unauthenticated |
| 13 | +from jupyter_server.auth.decorator import allow_unauthenticated, ws_authenticated |
14 | 14 | from jupyter_server.base.handlers import JupyterHandler
|
15 | 15 | from jupyter_server.base.websocket import WebSocketMixin
|
16 | 16 | from jupyter_server.serverapp import ServerApp
|
@@ -75,6 +75,12 @@ class NoAuthRulesWebsocketHandler(MockJupyterHandler):
|
75 | 75 | pass
|
76 | 76 |
|
77 | 77 |
|
| 78 | +class AuthenticatedWebsocketHandler(MockJupyterHandler): |
| 79 | + @ws_authenticated |
| 80 | + def get(self, *args, **kwargs) -> None: |
| 81 | + return super().get(*args, **kwargs) |
| 82 | + |
| 83 | + |
78 | 84 | class PermissiveWebsocketHandler(MockJupyterHandler):
|
79 | 85 | @allow_unauthenticated
|
80 | 86 | def get(self, *args, **kwargs) -> None:
|
@@ -126,6 +132,30 @@ async def test_websocket_auth_required(jp_serverapp, jp_ws_fetch):
|
126 | 132 | assert exception.value.code == 403
|
127 | 133 |
|
128 | 134 |
|
| 135 | +async def test_websocket_token_subprotocol_auth(jp_serverapp, jp_ws_fetch): |
| 136 | + app: ServerApp = jp_serverapp |
| 137 | + app.web_app.add_handlers( |
| 138 | + ".*$", |
| 139 | + [ |
| 140 | + (url_path_join(app.base_url, "ws"), AuthenticatedWebsocketHandler), |
| 141 | + ], |
| 142 | + ) |
| 143 | + |
| 144 | + with pytest.raises(HTTPClientError) as exception: |
| 145 | + ws = await jp_ws_fetch("ws", headers={"Authorization": ""}) |
| 146 | + assert exception.value.code == 403 |
| 147 | + token = jp_serverapp.identity_provider.token |
| 148 | + ws = await jp_ws_fetch( |
| 149 | + "ws", |
| 150 | + headers={ |
| 151 | + "Authorization": "", |
| 152 | + "Sec-WebSocket-Protocol": "v1.kernel.websocket.jupyter.org, v1.token.websocket.jupyter.org." |
| 153 | + + token, |
| 154 | + }, |
| 155 | + ) |
| 156 | + ws.close() |
| 157 | + |
| 158 | + |
129 | 159 | class IndiscriminateIdentityProvider(IdentityProvider):
|
130 | 160 | async def get_user(self, handler):
|
131 | 161 | return User(username="test")
|
|
0 commit comments