diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index b1a33ad2..cdc4571b 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -236,7 +236,11 @@ async def execute( f"{result_text}" ) - return ExecutionResult(errors=result.get("errors"), data=result.get("data")) + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) def subscribe( self, diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index aaa6686a..557636db 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -189,7 +189,9 @@ def _parse_answer( answer_type = "data" execution_result = ExecutionResult( - errors=payload.get("errors"), data=result.get("data") + errors=payload.get("errors"), + data=result.get("data"), + extensions=payload.get("extensions"), ) elif event == "phx_reply": diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 5eb2a36c..c7d03adb 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -171,7 +171,11 @@ def execute( # type: ignore if "errors" not in result and "data" not in result: raise TransportProtocolError("Server did not return a GraphQL result") - return ExecutionResult(errors=result.get("errors"), data=result.get("data")) + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) def close(self): """Closing the transport by closing the inner session""" diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 76a234bd..e7eb4e8f 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -303,7 +303,9 @@ def _parse_answer( ) execution_result = ExecutionResult( - errors=payload.get("errors"), data=payload.get("data") + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), ) elif answer_type == "error": diff --git a/setup.py b/setup.py index fdcaccf0..496e7f3f 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.1,<3.2", + "graphql-core>=3.1.4,<3.2", "yarl>=1.6,<2.0", ] diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 0bf8c1ba..815b4904 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -948,3 +948,35 @@ async def handler(request): expected_error = "Syntax Error: Unexpected Name 'BLAHBLAH'" assert expected_error in captured_err + + +query1_server_answer_with_extensions = ( + f'{{"data":{query1_server_answer_data}, "extensions":{{"key1": "val1"}}}}' +) + + +@pytest.mark.asyncio +async def test_aiohttp_query_with_extensions(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + execution_result = await session._execute(query) + + assert execution_result.extensions["key1"] == "val1" diff --git a/tests/test_requests.py b/tests/test_requests.py index 99d40bf1..a0f8ca27 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -216,3 +216,47 @@ def test_code(): sample_transport.execute(query) await run_sync_test(event_loop, server, test_code) + + +query1_server_answer_with_extensions = ( + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]},' + '"extensions": {"key1": "val1"}' + "}" +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_query_with_extensions( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url) + + with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + execution_result = session._execute(query) + + assert execution_result.extensions["key1"] == "val1" + + await run_sync_test(event_loop, server, test_code) diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index fc89dc80..e825c637 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -569,3 +569,33 @@ async def test_websocket_using_cli(event_loop, server, monkeypatch, capsys): received_answer = json.loads(captured_out) assert received_answer == expected_answer + + +query1_server_answer_with_extensions = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}},' + '"extensions": {{"key1": "val1"}}}}}}' +) + +server1_answers_with_extensions = [ + query1_server_answer_with_extensions, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers_with_extensions], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_websocket_simple_query_with_extensions( + event_loop, client_and_server, query_str +): + + session, server = client_and_server + + query = gql(query_str) + + execution_result = await session._execute(query) + + assert execution_result.extensions["key1"] == "val1"