Skip to content

Commit d4c9751

Browse files
authored
Add sync batching to requests sync transport (#431)
* Add `execute_batch` method for requests sync transport
1 parent 013fa6a commit d4c9751

12 files changed

+1621
-37
lines changed

gql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from .__version__ import __version__
1111
from .client import Client
1212
from .gql import gql
13+
from .graphql_request import GraphQLRequest
1314

1415
__all__ = [
1516
"__version__",
1617
"gql",
1718
"Client",
19+
"GraphQLRequest",
1820
]

gql/client.py

Lines changed: 164 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Callable,
99
Dict,
1010
Generator,
11+
List,
1112
Optional,
1213
TypeVar,
1314
Union,
@@ -27,6 +28,7 @@
2728
validate,
2829
)
2930

31+
from .graphql_request import GraphQLRequest
3032
from .transport.async_transport import AsyncTransport
3133
from .transport.exceptions import TransportClosed, TransportQueryError
3234
from .transport.local_schema import LocalSchemaTransport
@@ -236,6 +238,24 @@ def execute_sync(
236238
**kwargs,
237239
)
238240

241+
def execute_batch_sync(
242+
self,
243+
reqs: List[GraphQLRequest],
244+
serialize_variables: Optional[bool] = None,
245+
parse_result: Optional[bool] = None,
246+
get_execution_result: bool = False,
247+
**kwargs,
248+
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
249+
""":meta private:"""
250+
with self as session:
251+
return session.execute_batch(
252+
reqs,
253+
serialize_variables=serialize_variables,
254+
parse_result=parse_result,
255+
get_execution_result=get_execution_result,
256+
**kwargs,
257+
)
258+
239259
@overload
240260
async def execute_async(
241261
self,
@@ -375,7 +395,6 @@ def execute(
375395
"""
376396

377397
if isinstance(self.transport, AsyncTransport):
378-
379398
# Get the current asyncio event loop
380399
# Or create a new event loop if there isn't one (in a new Thread)
381400
try:
@@ -418,6 +437,48 @@ def execute(
418437
**kwargs,
419438
)
420439

440+
def execute_batch(
441+
self,
442+
reqs: List[GraphQLRequest],
443+
serialize_variables: Optional[bool] = None,
444+
parse_result: Optional[bool] = None,
445+
get_execution_result: bool = False,
446+
**kwargs,
447+
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
448+
"""Execute multiple GraphQL requests in a batch against the remote server using
449+
the transport provided during init.
450+
451+
This function **WILL BLOCK** until the result is received from the server.
452+
453+
Either the transport is sync and we execute the query synchronously directly
454+
OR the transport is async and we execute the query in the asyncio loop
455+
(blocking here until answer).
456+
457+
This method will:
458+
459+
- connect using the transport to get a session
460+
- execute the GraphQL requests on the transport session
461+
- close the session and close the connection to the server
462+
463+
If you want to perform multiple executions, it is better to use
464+
the context manager to keep a session active.
465+
466+
The extra arguments passed in the method will be passed to the transport
467+
execute method.
468+
"""
469+
470+
if isinstance(self.transport, AsyncTransport):
471+
raise NotImplementedError("Batching is not implemented for async yet.")
472+
473+
else: # Sync transports
474+
return self.execute_batch_sync(
475+
reqs,
476+
serialize_variables=serialize_variables,
477+
parse_result=parse_result,
478+
get_execution_result=get_execution_result,
479+
**kwargs,
480+
)
481+
421482
@overload
422483
def subscribe_async(
423484
self,
@@ -476,7 +537,6 @@ async def subscribe_async(
476537
]:
477538
""":meta private:"""
478539
async with self as session:
479-
480540
generator = session.subscribe(
481541
document,
482542
variable_values=variable_values,
@@ -600,7 +660,6 @@ def subscribe(
600660
pass
601661

602662
except (KeyboardInterrupt, Exception, GeneratorExit):
603-
604663
# Graceful shutdown
605664
asyncio.ensure_future(async_generator.aclose(), loop=loop)
606665

@@ -661,11 +720,9 @@ async def close_async(self):
661720
await self.transport.close()
662721

663722
async def __aenter__(self):
664-
665723
return await self.connect_async()
666724

667725
async def __aexit__(self, exc_type, exc, tb):
668-
669726
await self.close_async()
670727

671728
def connect_sync(self):
@@ -705,7 +762,6 @@ def close_sync(self):
705762
self.transport.close()
706763

707764
def __enter__(self):
708-
709765
return self.connect_sync()
710766

711767
def __exit__(self, *args):
@@ -880,6 +936,108 @@ def execute(
880936

881937
return result.data
882938

939+
def _execute_batch(
940+
self,
941+
reqs: List[GraphQLRequest],
942+
serialize_variables: Optional[bool] = None,
943+
parse_result: Optional[bool] = None,
944+
**kwargs,
945+
) -> List[ExecutionResult]:
946+
"""Execute multiple GraphQL requests in a batch, using
947+
the sync transport, returning a list of ExecutionResult objects.
948+
949+
:param reqs: List of requests that will be executed.
950+
:param serialize_variables: whether the variable values should be
951+
serialized. Used for custom scalars and/or enums.
952+
By default use the serialize_variables argument of the client.
953+
:param parse_result: Whether gql will unserialize the result.
954+
By default use the parse_results argument of the client.
955+
956+
The extra arguments are passed to the transport execute method."""
957+
958+
# Validate document
959+
if self.client.schema:
960+
for req in reqs:
961+
self.client.validate(req.document)
962+
963+
# Parse variable values for custom scalars if requested
964+
if serialize_variables or (
965+
serialize_variables is None and self.client.serialize_variables
966+
):
967+
reqs = [
968+
req.serialize_variable_values(self.client.schema)
969+
if req.variable_values is not None
970+
else req
971+
for req in reqs
972+
]
973+
974+
results = self.transport.execute_batch(reqs, **kwargs)
975+
976+
# Unserialize the result if requested
977+
if self.client.schema:
978+
if parse_result or (parse_result is None and self.client.parse_results):
979+
for result in results:
980+
result.data = parse_result_fn(
981+
self.client.schema,
982+
req.document,
983+
result.data,
984+
operation_name=req.operation_name,
985+
)
986+
987+
return results
988+
989+
def execute_batch(
990+
self,
991+
reqs: List[GraphQLRequest],
992+
serialize_variables: Optional[bool] = None,
993+
parse_result: Optional[bool] = None,
994+
get_execution_result: bool = False,
995+
**kwargs,
996+
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
997+
"""Execute multiple GraphQL requests in a batch, using
998+
the sync transport. This method sends the requests to the server all at once.
999+
1000+
Raises a TransportQueryError if an error has been returned in any
1001+
ExecutionResult.
1002+
1003+
:param reqs: List of requests that will be executed.
1004+
:param serialize_variables: whether the variable values should be
1005+
serialized. Used for custom scalars and/or enums.
1006+
By default use the serialize_variables argument of the client.
1007+
:param parse_result: Whether gql will unserialize the result.
1008+
By default use the parse_results argument of the client.
1009+
:param get_execution_result: return the full ExecutionResult instance instead of
1010+
only the "data" field. Necessary if you want to get the "extensions" field.
1011+
1012+
The extra arguments are passed to the transport execute method."""
1013+
1014+
# Validate and execute on the transport
1015+
results = self._execute_batch(
1016+
reqs,
1017+
serialize_variables=serialize_variables,
1018+
parse_result=parse_result,
1019+
**kwargs,
1020+
)
1021+
1022+
for result in results:
1023+
# Raise an error if an error is returned in the ExecutionResult object
1024+
if result.errors:
1025+
raise TransportQueryError(
1026+
str_first_element(result.errors),
1027+
errors=result.errors,
1028+
data=result.data,
1029+
extensions=result.extensions,
1030+
)
1031+
1032+
assert (
1033+
result.data is not None
1034+
), "Transport returned an ExecutionResult without data or errors"
1035+
1036+
if get_execution_result:
1037+
return results
1038+
1039+
return cast(List[Dict[str, Any]], [result.data for result in results])
1040+
8831041
def fetch_schema(self) -> None:
8841042
"""Fetch the GraphQL schema explicitly using introspection.
8851043
@@ -966,7 +1124,6 @@ async def _subscribe(
9661124

9671125
try:
9681126
async for result in inner_generator:
969-
9701127
if self.client.schema:
9711128
if parse_result or (
9721129
parse_result is None and self.client.parse_results
@@ -1070,7 +1227,6 @@ async def subscribe(
10701227
try:
10711228
# Validate and subscribe on the transport
10721229
async for result in inner_generator:
1073-
10741230
# Raise an error if an error is returned in the ExecutionResult object
10751231
if result.errors:
10761232
raise TransportQueryError(
@@ -1343,7 +1499,6 @@ async def _connection_loop(self):
13431499
"""
13441500

13451501
while True:
1346-
13471502
# Connect to the transport with the retry decorator
13481503
# By default it should keep retrying until it connect
13491504
await self._connect_with_retries()

gql/graphql_request.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from dataclasses import dataclass
2+
from typing import Any, Dict, Optional
3+
4+
from graphql import DocumentNode, GraphQLSchema
5+
6+
from .utilities import serialize_variable_values
7+
8+
9+
@dataclass(frozen=True)
10+
class GraphQLRequest:
11+
"""GraphQL Request to be executed."""
12+
13+
document: DocumentNode
14+
"""GraphQL query as AST Node object."""
15+
16+
variable_values: Optional[Dict[str, Any]] = None
17+
"""Dictionary of input parameters (Default: None)."""
18+
19+
operation_name: Optional[str] = None
20+
"""
21+
Name of the operation that shall be executed.
22+
Only required in multi-operation documents (Default: None).
23+
"""
24+
25+
def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
26+
assert self.variable_values
27+
28+
return GraphQLRequest(
29+
document=self.document,
30+
variable_values=serialize_variable_values(
31+
schema=schema,
32+
document=self.document,
33+
variable_values=self.variable_values,
34+
operation_name=self.operation_name,
35+
),
36+
operation_name=self.operation_name,
37+
)

gql/transport/aiohttp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ async def execute(
205205
document: DocumentNode,
206206
variable_values: Optional[Dict[str, Any]] = None,
207207
operation_name: Optional[str] = None,
208-
extra_args: Dict[str, Any] = None,
208+
extra_args: Optional[Dict[str, Any]] = None,
209209
upload_files: bool = False,
210210
) -> ExecutionResult:
211211
"""Execute the provided document AST against the configured remote server

0 commit comments

Comments
 (0)