Skip to content

Commit 5348b0a

Browse files
Fix race condition and add non regression test (#106)
Co-authored-by: Manuel Bojato <[email protected]>
1 parent 87cc5b2 commit 5348b0a

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

gql/transport/websockets.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ def __init__(
134134
self._no_more_listeners: asyncio.Event = asyncio.Event()
135135
self._no_more_listeners.set()
136136

137+
self._connecting: bool = False
138+
137139
self.close_exception: Optional[Exception] = None
138140

139141
async def _send(self, message: str) -> None:
@@ -472,7 +474,11 @@ async def connect(self) -> None:
472474

473475
GRAPHQLWS_SUBPROTOCOL: Subprotocol = cast(Subprotocol, "graphql-ws")
474476

475-
if self.websocket is None:
477+
if self.websocket is None and not self._connecting:
478+
479+
# Set connecting to True to avoid a race condition if user is trying
480+
# to connect twice using the same client at the same time
481+
self._connecting = True
476482

477483
# If the ssl parameter is not provided,
478484
# generate the ssl value depending on the url
@@ -494,9 +500,13 @@ async def connect(self) -> None:
494500

495501
# Connection to the specified url
496502
# Generate a TimeoutError if taking more than connect_timeout seconds
497-
self.websocket = await asyncio.wait_for(
498-
websockets.connect(self.url, **connect_args,), self.connect_timeout,
499-
)
503+
# Set the _connecting flag to False after in all cases
504+
try:
505+
self.websocket = await asyncio.wait_for(
506+
websockets.connect(self.url, **connect_args,), self.connect_timeout,
507+
)
508+
finally:
509+
self._connecting = False
500510

501511
self.next_query_id = 1
502512
self.close_exception = None

tests/test_websocket_exceptions.py

+29
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from gql import Client, gql
1010
from gql.transport.exceptions import (
11+
TransportAlreadyConnected,
1112
TransportClosed,
1213
TransportProtocolError,
1314
TransportQueryError,
@@ -319,3 +320,31 @@ async def test_websocket_server_sending_invalid_query_errors(event_loop, server)
319320
# Invalid server message is ignored
320321
async with Client(transport=sample_transport):
321322
await asyncio.sleep(2 * MS)
323+
324+
325+
@pytest.mark.asyncio
326+
@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True)
327+
async def test_websocket_non_regression_bug_105(event_loop, server):
328+
329+
# This test will check a fix to a race condition which happens if the user is trying
330+
# to connect using the same client twice at the same time
331+
# See bug #105
332+
333+
url = f"ws://{server.hostname}:{server.port}/graphql"
334+
print(f"url = {url}")
335+
336+
sample_transport = WebsocketsTransport(url=url)
337+
338+
client = Client(transport=sample_transport)
339+
340+
# Create a coroutine which start the connection with the transport but does nothing
341+
async def client_connect(client):
342+
async with client:
343+
await asyncio.sleep(2 * MS)
344+
345+
# Create two tasks which will try to connect using the same client (not allowed)
346+
connect_task1 = asyncio.ensure_future(client_connect(client))
347+
connect_task2 = asyncio.ensure_future(client_connect(client))
348+
349+
with pytest.raises(TransportAlreadyConnected):
350+
await asyncio.gather(connect_task1, connect_task2)

0 commit comments

Comments
 (0)