From 1027e73aca033c832abbdc9d370c1e617db11d89 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 28 Jun 2020 14:42:12 +0200 Subject: [PATCH] Fix race condition and add non regression test --- gql/transport/websockets.py | 18 ++++++++++++++---- tests/test_websocket_exceptions.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 081c677b..5b7c1217 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -134,6 +134,8 @@ def __init__( self._no_more_listeners: asyncio.Event = asyncio.Event() self._no_more_listeners.set() + self._connecting: bool = False + self.close_exception: Optional[Exception] = None async def _send(self, message: str) -> None: @@ -467,7 +469,11 @@ async def connect(self) -> None: GRAPHQLWS_SUBPROTOCOL: Subprotocol = cast(Subprotocol, "graphql-ws") - if self.websocket is None: + if self.websocket is None and not self._connecting: + + # Set connecting to True to avoid a race condition if user is trying + # to connect twice using the same client at the same time + self._connecting = True # If the ssl parameter is not provided, # generate the ssl value depending on the url @@ -489,9 +495,13 @@ async def connect(self) -> None: # Connection to the specified url # Generate a TimeoutError if taking more than connect_timeout seconds - self.websocket = await asyncio.wait_for( - websockets.connect(self.url, **connect_args,), self.connect_timeout, - ) + # Set the _connecting flag to False after in all cases + try: + self.websocket = await asyncio.wait_for( + websockets.connect(self.url, **connect_args,), self.connect_timeout, + ) + finally: + self._connecting = False self.next_query_id = 1 self.close_exception = None diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index a5482d43..699cf7ce 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -7,6 +7,7 @@ from gql import Client, gql from gql.transport.exceptions import ( + TransportAlreadyConnected, TransportClosed, TransportProtocolError, TransportQueryError, @@ -294,3 +295,31 @@ async def test_websocket_server_sending_invalid_query_errors(event_loop, server) # Invalid server message is ignored async with Client(transport=sample_transport): await asyncio.sleep(2 * MS) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) +async def test_websocket_non_regression_bug_105(event_loop, server): + + # This test will check a fix to a race condition which happens if the user is trying + # to connect using the same client twice at the same time + # See bug #105 + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + # Create a coroutine which start the connection with the transport but does nothing + async def client_connect(client): + async with client: + await asyncio.sleep(2 * MS) + + # Create two tasks which will try to connect using the same client (not allowed) + connect_task1 = asyncio.ensure_future(client_connect(client)) + connect_task2 = asyncio.ensure_future(client_connect(client)) + + with pytest.raises(TransportAlreadyConnected): + await asyncio.gather(connect_task1, connect_task2)