Skip to content

Add extensions field to ExecutionResult (#188) #190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,11 @@ async def execute(
f"{result_text}"
)

return ExecutionResult(errors=result.get("errors"), data=result.get("data"))
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)

def subscribe(
self,
Expand Down
4 changes: 3 additions & 1 deletion gql/transport/phoenix_channel_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ def _parse_answer(
answer_type = "data"

execution_result = ExecutionResult(
errors=payload.get("errors"), data=result.get("data")
errors=payload.get("errors"),
data=result.get("data"),
extensions=payload.get("extensions"),
)

elif event == "phx_reply":
Expand Down
6 changes: 5 additions & 1 deletion gql/transport/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ def execute( # type: ignore
if "errors" not in result and "data" not in result:
raise TransportProtocolError("Server did not return a GraphQL result")

return ExecutionResult(errors=result.get("errors"), data=result.get("data"))
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)

def close(self):
"""Closing the transport by closing the inner session"""
Expand Down
4 changes: 3 additions & 1 deletion gql/transport/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def _parse_answer(
)

execution_result = ExecutionResult(
errors=payload.get("errors"), data=payload.get("data")
errors=payload.get("errors"),
data=payload.get("data"),
extensions=payload.get("extensions"),
)

elif answer_type == "error":
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from setuptools import setup, find_packages

install_requires = [
"graphql-core>=3.1,<3.2",
"graphql-core>=3.1.4,<3.2",
"yarl>=1.6,<2.0",
]

Expand Down
32 changes: 32 additions & 0 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,3 +948,35 @@ async def handler(request):
expected_error = "Syntax Error: Unexpected Name 'BLAHBLAH'"

assert expected_error in captured_err


query1_server_answer_with_extensions = (
f'{{"data":{query1_server_answer_data}, "extensions":{{"key1": "val1"}}}}'
)


@pytest.mark.asyncio
async def test_aiohttp_query_with_extensions(event_loop, aiohttp_server):
from aiohttp import web
from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
return web.Response(
text=query1_server_answer_with_extensions, content_type="application/json"
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

sample_transport = AIOHTTPTransport(url=url, timeout=10)

async with Client(transport=sample_transport,) as session:

query = gql(query1_str)

execution_result = await session._execute(query)

assert execution_result.extensions["key1"] == "val1"
44 changes: 44 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,47 @@ def test_code():
sample_transport.execute(query)

await run_sync_test(event_loop, server, test_code)


query1_server_answer_with_extensions = (
'{"data":{"continents":['
'{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},'
'{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},'
'{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},'
'{"code":"SA","name":"South America"}]},'
'"extensions": {"key1": "val1"}'
"}"
)


@pytest.mark.aiohttp
@pytest.mark.asyncio
async def test_requests_query_with_extensions(
event_loop, aiohttp_server, run_sync_test
):
from aiohttp import web
from gql.transport.requests import RequestsHTTPTransport

async def handler(request):
return web.Response(
text=query1_server_answer_with_extensions, content_type="application/json"
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

def test_code():
sample_transport = RequestsHTTPTransport(url=url)

with Client(transport=sample_transport,) as session:

query = gql(query1_str)

execution_result = session._execute(query)

assert execution_result.extensions["key1"] == "val1"

await run_sync_test(event_loop, server, test_code)
30 changes: 30 additions & 0 deletions tests/test_websocket_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,33 @@ async def test_websocket_using_cli(event_loop, server, monkeypatch, capsys):
received_answer = json.loads(captured_out)

assert received_answer == expected_answer


query1_server_answer_with_extensions = (
'{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":['
'{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},'
'{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},'
'{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},'
'{{"code":"SA","name":"South America"}}]}},'
'"extensions": {{"key1": "val1"}}}}}}'
)

server1_answers_with_extensions = [
query1_server_answer_with_extensions,
]


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server1_answers_with_extensions], indirect=True)
@pytest.mark.parametrize("query_str", [query1_str])
async def test_websocket_simple_query_with_extensions(
event_loop, client_and_server, query_str
):

session, server = client_and_server

query = gql(query_str)

execution_result = await session._execute(query)

assert execution_result.extensions["key1"] == "val1"