Skip to content

Fix race condition when trying to connect twice at the same time with the websockets transport #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions gql/transport/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def __init__(
self._no_more_listeners: asyncio.Event = asyncio.Event()
self._no_more_listeners.set()

self._connecting: bool = False

self.close_exception: Optional[Exception] = None

async def _send(self, message: str) -> None:
Expand Down Expand Up @@ -472,7 +474,11 @@ async def connect(self) -> None:

GRAPHQLWS_SUBPROTOCOL: Subprotocol = cast(Subprotocol, "graphql-ws")

if self.websocket is None:
if self.websocket is None and not self._connecting:

# Set connecting to True to avoid a race condition if user is trying
# to connect twice using the same client at the same time
self._connecting = True

# If the ssl parameter is not provided,
# generate the ssl value depending on the url
Expand All @@ -494,9 +500,13 @@ async def connect(self) -> None:

# Connection to the specified url
# Generate a TimeoutError if taking more than connect_timeout seconds
self.websocket = await asyncio.wait_for(
websockets.connect(self.url, **connect_args,), self.connect_timeout,
)
# Set the _connecting flag to False after in all cases
try:
self.websocket = await asyncio.wait_for(
websockets.connect(self.url, **connect_args,), self.connect_timeout,
)
finally:
self._connecting = False

self.next_query_id = 1
self.close_exception = None
Expand Down
29 changes: 29 additions & 0 deletions tests/test_websocket_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from gql import Client, gql
from gql.transport.exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportQueryError,
Expand Down Expand Up @@ -319,3 +320,31 @@ async def test_websocket_server_sending_invalid_query_errors(event_loop, server)
# Invalid server message is ignored
async with Client(transport=sample_transport):
await asyncio.sleep(2 * MS)


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True)
async def test_websocket_non_regression_bug_105(event_loop, server):

# This test will check a fix to a race condition which happens if the user is trying
# to connect using the same client twice at the same time
# See bug #105

url = f"ws://{server.hostname}:{server.port}/graphql"
print(f"url = {url}")

sample_transport = WebsocketsTransport(url=url)

client = Client(transport=sample_transport)

# Create a coroutine which start the connection with the transport but does nothing
async def client_connect(client):
async with client:
await asyncio.sleep(2 * MS)

# Create two tasks which will try to connect using the same client (not allowed)
connect_task1 = asyncio.ensure_future(client_connect(client))
connect_task2 = asyncio.ensure_future(client_connect(client))

with pytest.raises(TransportAlreadyConnected):
await asyncio.gather(connect_task1, connect_task2)