diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 1e464104..0c6eb3fc 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,7 +1,7 @@ import io import json import logging -from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union import requests from graphql import DocumentNode, ExecutionResult, print_ast @@ -47,6 +47,8 @@ def __init__( method: str = "POST", retry_backoff_factor: float = 0.1, retry_status_forcelist: Collection[int] = _default_retry_codes, + json_serialize: Callable = json.dumps, + json_deserialize: Callable = json.loads, **kwargs: Any, ): """Initialize the transport with the given request parameters. @@ -73,6 +75,10 @@ def __init__( should force a retry on. A retry is initiated if the request method is in allowed_methods and the response status code is in status_forcelist. (Default: [429, 500, 502, 503, 504]) + :param json_serialize: Json serializer callable. + By default json.dumps() function + :param json_deserialize: Json deserializer callable. + By default json.loads() function :param kwargs: Optional arguments that ``request`` takes. These can be seen at the `requests`_ source code or the official `docs`_ @@ -90,6 +96,8 @@ def __init__( self.method = method self.retry_backoff_factor = retry_backoff_factor self.retry_status_forcelist = retry_status_forcelist + self.json_serialize: Callable = json_serialize + self.json_deserialize: Callable = json_deserialize self.kwargs = kwargs self.session = None @@ -174,7 +182,7 @@ def execute( # type: ignore payload["variables"] = nulled_variable_values # Add the payload to the operations field - operations_str = json.dumps(payload) + operations_str = self.json_serialize(payload) log.debug("operations %s", operations_str) # Generate the file map @@ -188,7 +196,7 @@ def execute( # type: ignore file_streams = {str(i): files[path] for i, path in enumerate(files)} # Add the file map field - file_map_str = json.dumps(file_map) + file_map_str = self.json_serialize(file_map) log.debug("file_map %s", file_map_str) fields = {"operations": operations_str, "map": file_map_str} @@ -224,7 +232,7 @@ def execute( # type: ignore # Log the payload if log.isEnabledFor(logging.INFO): - log.info(">>> %s", json.dumps(payload)) + log.info(">>> %s", self.json_serialize(payload)) # Pass kwargs to requests post method post_args.update(self.kwargs) @@ -257,7 +265,10 @@ def raise_response_error(resp: requests.Response, reason: str): ) try: - result = response.json() + if self.json_deserialize == json.loads: + result = response.json() + else: + result = self.json_deserialize(response.text) if log.isEnabledFor(logging.INFO): log.info("<<< %s", response.text) @@ -396,7 +407,7 @@ def _build_batch_post_args( # Log the payload if log.isEnabledFor(logging.INFO): - log.info(">>> %s", json.dumps(post_args[data_key])) + log.info(">>> %s", self.json_serialize(post_args[data_key])) # Pass kwargs to requests post method post_args.update(self.kwargs) diff --git a/tests/test_requests.py b/tests/test_requests.py index 639d2b73..ba666243 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -923,3 +923,109 @@ def test_code(): assert transport.session is None await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_json_serializer( + event_loop, aiohttp_server, run_sync_test, caplog +): + import json + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + + request_text = await request.text() + print("Received on backend: " + request_text) + + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport( + url=url, + json_serialize=lambda e: json.dumps(e, separators=(",", ":")), + ) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking that there is no space after the colon in the log + expected_log = '"query":"query getContinents' + assert expected_log in caplog.text + + await run_sync_test(event_loop, server, test_code) + + +query_float_str = """ + query getPi { + pi + } +""" + +query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}' + +query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_json_deserializer(event_loop, aiohttp_server, run_sync_test): + import json + from aiohttp import web + from decimal import Decimal + from functools import partial + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response( + text=query_float_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + + json_loads = partial(json.loads, parse_float=Decimal) + + transport = RequestsHTTPTransport( + url=url, + json_deserialize=json_loads, + ) + + with Client(transport=transport) as session: + + query = gql(query_float_str) + + # Execute query asynchronously + result = session.execute(query) + + pi = result["pi"] + + assert pi == Decimal("3.141592653589793238462643383279502884197") + + await run_sync_test(event_loop, server, test_code)