Skip to content

Commit dc12af3

Browse files
committed
Merge branch 'tim/SessionClientDisconnectedError' into tim/StreamlitRuntime
* tim/SessionClientDisconnectedError: better docstring typo SessionClientDisconnectedError, with test WIP * Remove staticmethod decorator from __call__ method of SingletonAPI to avoid mypy bug python/mypy#7781 (streamlit#4981)
2 parents ba0df18 + 88266e1 commit dc12af3

File tree

3 files changed

+64
-7
lines changed

3 files changed

+64
-7
lines changed

lib/streamlit/caching/singleton_decorator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,22 +122,24 @@ class SingletonAPI:
122122

123123
# Bare decorator usage
124124
@overload
125-
@staticmethod
126-
def __call__(func: F) -> F:
125+
def __call__(self, func: F) -> F:
127126
...
128127

129128
# Decorator with arguments
130129
@overload
131-
@staticmethod
132130
def __call__(
131+
self,
133132
*,
134133
show_spinner: bool = True,
135134
suppress_st_warning=False,
136135
) -> Callable[[F], F]:
137136
...
138137

139-
@staticmethod
138+
# __call__ should be a static method, but there's a mypy bug that
139+
# breaks type checking for overloaded static functions:
140+
# https://github.com/python/mypy/issues/7781
140141
def __call__(
142+
self,
141143
func: Optional[F] = None,
142144
*,
143145
show_spinner: bool = True,

lib/streamlit/web/server/server.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,21 @@
119119
SCRIPT_RUN_CHECK_TIMEOUT = 60
120120

121121

122+
class SessionClientDisconnectedError(Exception):
123+
"""Raised by operations on a disconnected SessionClient."""
124+
125+
pass
126+
127+
122128
class SessionClient(Protocol):
123129
"""Interface for sending data to a session's client."""
124130

125131
def write_forward_msg(self, msg: ForwardMsg) -> None:
126-
"""Deliver a ForwardMsg to the client."""
132+
"""Deliver a ForwardMsg to the client.
133+
134+
If the SessionClient has been disconnected, it should raise a
135+
SessionClientDisconnectedError.
136+
"""
127137

128138

129139
class SessionInfo:
@@ -515,7 +525,7 @@ async def _loop_coroutine(
515525
for msg in msg_list:
516526
try:
517527
self._send_message(session_info, msg)
518-
except tornado.websocket.WebSocketClosedError:
528+
except SessionClientDisconnectedError:
519529
self._close_app_session(session_info.session.id)
520530
await asyncio.sleep(0)
521531
await asyncio.sleep(0)
@@ -722,7 +732,10 @@ def check_origin(self, origin: str) -> bool:
722732

723733
def write_forward_msg(self, msg: ForwardMsg) -> None:
724734
"""Send a ForwardMsg to the browser."""
725-
self.write_message(serialize_forward_msg(msg), binary=True)
735+
try:
736+
self.write_message(serialize_forward_msg(msg), binary=True)
737+
except tornado.websocket.WebSocketClosedError as e:
738+
raise SessionClientDisconnectedError from e
726739

727740
def open(self, *args, **kwargs) -> Optional[Awaitable[None]]:
728741
# Extract user info from the X-Streamlit-User header

lib/tests/streamlit/web/server/server_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,48 @@ async def test_orphaned_upload_file_deletion(self):
372372
[],
373373
)
374374

375+
@tornado.testing.gen_test
376+
async def test_send_message_to_disconnected_websocket(self):
377+
"""Sending a message to a disconnected SessionClient raises an error.
378+
We should gracefully handle the error by cleaning up the session.
379+
"""
380+
with patch(
381+
"streamlit.web.server.server.LocalSourcesWatcher"
382+
), self._patch_app_session():
383+
await self.start_server_loop()
384+
await self.ws_connect()
385+
386+
# Get the server's socket and session for this client
387+
session_info = list(self.server._session_info_by_id.values())[0]
388+
389+
with patch.object(
390+
session_info.session, "flush_browser_queue"
391+
) as flush_browser_queue, patch.object(
392+
session_info.client, "write_message"
393+
) as ws_write_message:
394+
# Patch flush_browser_queue to simulate a pending message.
395+
flush_browser_queue.return_value = [_create_dataframe_msg([1, 2, 3])]
396+
397+
# Patch the session's WebsocketHandler to raise a
398+
# WebSocketClosedError when we write to it.
399+
ws_write_message.side_effect = tornado.websocket.WebSocketClosedError()
400+
401+
# Tick the server. Our session's browser_queue will be flushed,
402+
# and the Websocket client's write_message will be called,
403+
# raising our WebSocketClosedError.
404+
while not flush_browser_queue.called:
405+
self.server._need_send_data.set()
406+
await asyncio.sleep(0)
407+
408+
flush_browser_queue.assert_called_once()
409+
ws_write_message.assert_called_once()
410+
411+
# Our session should have been removed from the server as
412+
# a result of the WebSocketClosedError.
413+
self.assertIsNone(
414+
self.server._get_session_info(session_info.session.id)
415+
)
416+
375417

376418
class ServerUtilsTest(unittest.TestCase):
377419
def test_is_url_from_allowed_origins_allowed_domains(self):

0 commit comments

Comments
 (0)