Skip to content

Commit 1027e73

Browse files
committed
Fix race condition and add non regression test
1 parent ce13236 commit 1027e73

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

gql/transport/websockets.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ def __init__(
134134
self._no_more_listeners: asyncio.Event = asyncio.Event()
135135
self._no_more_listeners.set()
136136

137+
self._connecting: bool = False
138+
137139
self.close_exception: Optional[Exception] = None
138140

139141
async def _send(self, message: str) -> None:
@@ -467,7 +469,11 @@ async def connect(self) -> None:
467469

468470
GRAPHQLWS_SUBPROTOCOL: Subprotocol = cast(Subprotocol, "graphql-ws")
469471

470-
if self.websocket is None:
472+
if self.websocket is None and not self._connecting:
473+
474+
# Set connecting to True to avoid a race condition if user is trying
475+
# to connect twice using the same client at the same time
476+
self._connecting = True
471477

472478
# If the ssl parameter is not provided,
473479
# generate the ssl value depending on the url
@@ -489,9 +495,13 @@ async def connect(self) -> None:
489495

490496
# Connection to the specified url
491497
# Generate a TimeoutError if taking more than connect_timeout seconds
492-
self.websocket = await asyncio.wait_for(
493-
websockets.connect(self.url, **connect_args,), self.connect_timeout,
494-
)
498+
# Set the _connecting flag to False after in all cases
499+
try:
500+
self.websocket = await asyncio.wait_for(
501+
websockets.connect(self.url, **connect_args,), self.connect_timeout,
502+
)
503+
finally:
504+
self._connecting = False
495505

496506
self.next_query_id = 1
497507
self.close_exception = None

tests/test_websocket_exceptions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from gql import Client, gql
99
from gql.transport.exceptions import (
10+
TransportAlreadyConnected,
1011
TransportClosed,
1112
TransportProtocolError,
1213
TransportQueryError,
@@ -294,3 +295,31 @@ async def test_websocket_server_sending_invalid_query_errors(event_loop, server)
294295
# Invalid server message is ignored
295296
async with Client(transport=sample_transport):
296297
await asyncio.sleep(2 * MS)
298+
299+
300+
@pytest.mark.asyncio
301+
@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True)
302+
async def test_websocket_non_regression_bug_105(event_loop, server):
303+
304+
# This test will check a fix to a race condition which happens if the user is trying
305+
# to connect using the same client twice at the same time
306+
# See bug #105
307+
308+
url = f"ws://{server.hostname}:{server.port}/graphql"
309+
print(f"url = {url}")
310+
311+
sample_transport = WebsocketsTransport(url=url)
312+
313+
client = Client(transport=sample_transport)
314+
315+
# Create a coroutine which start the connection with the transport but does nothing
316+
async def client_connect(client):
317+
async with client:
318+
await asyncio.sleep(2 * MS)
319+
320+
# Create two tasks which will try to connect using the same client (not allowed)
321+
connect_task1 = asyncio.ensure_future(client_connect(client))
322+
connect_task2 = asyncio.ensure_future(client_connect(client))
323+
324+
with pytest.raises(TransportAlreadyConnected):
325+
await asyncio.gather(connect_task1, connect_task2)

0 commit comments

Comments
 (0)