diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 6004442d..4367aef6 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -243,16 +243,14 @@ async def _send_query( query_id = self.next_query_id self.next_query_id += 1 + payload: Dict[str, Any] = {"query": print_ast(document)} + if variable_values: + payload["variables"] = variable_values + if operation_name: + payload["operationName"] = operation_name + query_str = json.dumps( - { - "id": str(query_id), - "type": "start", - "payload": { - "variables": variable_values or {}, - "operationName": operation_name or "", - "query": print_ast(document), - }, - } + {"id": str(query_id), "type": "start", "payload": payload} ) await self._send(query_str) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 93cbde4b..2e42cb1c 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -320,6 +320,34 @@ async def test_websocket_subscription_slow_consumer( assert count == -1 +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_with_operation_name( + event_loop, client_and_server, subscription_str +): + + session, server = client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe( + subscription, operation_name="CountdownSubscription" + ): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + # Check that the query contains the operationName + assert '"operationName": "CountdownSubscription"' in logged_messages[0] + + WITH_KEEPALIVE = True