|
8 | 8 | Callable,
|
9 | 9 | Dict,
|
10 | 10 | Generator,
|
| 11 | + List, |
11 | 12 | Optional,
|
12 | 13 | TypeVar,
|
13 | 14 | Union,
|
|
27 | 28 | validate,
|
28 | 29 | )
|
29 | 30 |
|
| 31 | +from .graphql_request import GraphQLRequest |
30 | 32 | from .transport.async_transport import AsyncTransport
|
31 | 33 | from .transport.exceptions import TransportClosed, TransportQueryError
|
32 | 34 | from .transport.local_schema import LocalSchemaTransport
|
@@ -236,6 +238,24 @@ def execute_sync(
|
236 | 238 | **kwargs,
|
237 | 239 | )
|
238 | 240 |
|
| 241 | + def execute_batch_sync( |
| 242 | + self, |
| 243 | + reqs: List[GraphQLRequest], |
| 244 | + serialize_variables: Optional[bool] = None, |
| 245 | + parse_result: Optional[bool] = None, |
| 246 | + get_execution_result: bool = False, |
| 247 | + **kwargs, |
| 248 | + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: |
| 249 | + """:meta private:""" |
| 250 | + with self as session: |
| 251 | + return session.execute_batch( |
| 252 | + reqs, |
| 253 | + serialize_variables=serialize_variables, |
| 254 | + parse_result=parse_result, |
| 255 | + get_execution_result=get_execution_result, |
| 256 | + **kwargs, |
| 257 | + ) |
| 258 | + |
239 | 259 | @overload
|
240 | 260 | async def execute_async(
|
241 | 261 | self,
|
@@ -375,7 +395,6 @@ def execute(
|
375 | 395 | """
|
376 | 396 |
|
377 | 397 | if isinstance(self.transport, AsyncTransport):
|
378 |
| - |
379 | 398 | # Get the current asyncio event loop
|
380 | 399 | # Or create a new event loop if there isn't one (in a new Thread)
|
381 | 400 | try:
|
@@ -418,6 +437,48 @@ def execute(
|
418 | 437 | **kwargs,
|
419 | 438 | )
|
420 | 439 |
|
| 440 | + def execute_batch( |
| 441 | + self, |
| 442 | + reqs: List[GraphQLRequest], |
| 443 | + serialize_variables: Optional[bool] = None, |
| 444 | + parse_result: Optional[bool] = None, |
| 445 | + get_execution_result: bool = False, |
| 446 | + **kwargs, |
| 447 | + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: |
| 448 | + """Execute multiple GraphQL requests in a batch against the remote server using |
| 449 | + the transport provided during init. |
| 450 | +
|
| 451 | + This function **WILL BLOCK** until the result is received from the server. |
| 452 | +
|
| 453 | + Either the transport is sync and we execute the query synchronously directly |
| 454 | + OR the transport is async and we execute the query in the asyncio loop |
| 455 | + (blocking here until answer). |
| 456 | +
|
| 457 | + This method will: |
| 458 | +
|
| 459 | + - connect using the transport to get a session |
| 460 | + - execute the GraphQL requests on the transport session |
| 461 | + - close the session and close the connection to the server |
| 462 | +
|
| 463 | + If you want to perform multiple executions, it is better to use |
| 464 | + the context manager to keep a session active. |
| 465 | +
|
| 466 | + The extra arguments passed in the method will be passed to the transport |
| 467 | + execute method. |
| 468 | + """ |
| 469 | + |
| 470 | + if isinstance(self.transport, AsyncTransport): |
| 471 | + raise NotImplementedError("Batching is not implemented for async yet.") |
| 472 | + |
| 473 | + else: # Sync transports |
| 474 | + return self.execute_batch_sync( |
| 475 | + reqs, |
| 476 | + serialize_variables=serialize_variables, |
| 477 | + parse_result=parse_result, |
| 478 | + get_execution_result=get_execution_result, |
| 479 | + **kwargs, |
| 480 | + ) |
| 481 | + |
421 | 482 | @overload
|
422 | 483 | def subscribe_async(
|
423 | 484 | self,
|
@@ -476,7 +537,6 @@ async def subscribe_async(
|
476 | 537 | ]:
|
477 | 538 | """:meta private:"""
|
478 | 539 | async with self as session:
|
479 |
| - |
480 | 540 | generator = session.subscribe(
|
481 | 541 | document,
|
482 | 542 | variable_values=variable_values,
|
@@ -600,7 +660,6 @@ def subscribe(
|
600 | 660 | pass
|
601 | 661 |
|
602 | 662 | except (KeyboardInterrupt, Exception, GeneratorExit):
|
603 |
| - |
604 | 663 | # Graceful shutdown
|
605 | 664 | asyncio.ensure_future(async_generator.aclose(), loop=loop)
|
606 | 665 |
|
@@ -661,11 +720,9 @@ async def close_async(self):
|
661 | 720 | await self.transport.close()
|
662 | 721 |
|
663 | 722 | async def __aenter__(self):
|
664 |
| - |
665 | 723 | return await self.connect_async()
|
666 | 724 |
|
667 | 725 | async def __aexit__(self, exc_type, exc, tb):
|
668 |
| - |
669 | 726 | await self.close_async()
|
670 | 727 |
|
671 | 728 | def connect_sync(self):
|
@@ -705,7 +762,6 @@ def close_sync(self):
|
705 | 762 | self.transport.close()
|
706 | 763 |
|
707 | 764 | def __enter__(self):
|
708 |
| - |
709 | 765 | return self.connect_sync()
|
710 | 766 |
|
711 | 767 | def __exit__(self, *args):
|
@@ -880,6 +936,108 @@ def execute(
|
880 | 936 |
|
881 | 937 | return result.data
|
882 | 938 |
|
| 939 | + def _execute_batch( |
| 940 | + self, |
| 941 | + reqs: List[GraphQLRequest], |
| 942 | + serialize_variables: Optional[bool] = None, |
| 943 | + parse_result: Optional[bool] = None, |
| 944 | + **kwargs, |
| 945 | + ) -> List[ExecutionResult]: |
| 946 | + """Execute multiple GraphQL requests in a batch, using |
| 947 | + the sync transport, returning a list of ExecutionResult objects. |
| 948 | +
|
| 949 | + :param reqs: List of requests that will be executed. |
| 950 | + :param serialize_variables: whether the variable values should be |
| 951 | + serialized. Used for custom scalars and/or enums. |
| 952 | + By default use the serialize_variables argument of the client. |
| 953 | + :param parse_result: Whether gql will unserialize the result. |
| 954 | + By default use the parse_results argument of the client. |
| 955 | +
|
| 956 | + The extra arguments are passed to the transport execute method.""" |
| 957 | + |
| 958 | + # Validate document |
| 959 | + if self.client.schema: |
| 960 | + for req in reqs: |
| 961 | + self.client.validate(req.document) |
| 962 | + |
| 963 | + # Parse variable values for custom scalars if requested |
| 964 | + if serialize_variables or ( |
| 965 | + serialize_variables is None and self.client.serialize_variables |
| 966 | + ): |
| 967 | + reqs = [ |
| 968 | + req.serialize_variable_values(self.client.schema) |
| 969 | + if req.variable_values is not None |
| 970 | + else req |
| 971 | + for req in reqs |
| 972 | + ] |
| 973 | + |
| 974 | + results = self.transport.execute_batch(reqs, **kwargs) |
| 975 | + |
| 976 | + # Unserialize the result if requested |
| 977 | + if self.client.schema: |
| 978 | + if parse_result or (parse_result is None and self.client.parse_results): |
| 979 | + for result in results: |
| 980 | + result.data = parse_result_fn( |
| 981 | + self.client.schema, |
| 982 | + req.document, |
| 983 | + result.data, |
| 984 | + operation_name=req.operation_name, |
| 985 | + ) |
| 986 | + |
| 987 | + return results |
| 988 | + |
| 989 | + def execute_batch( |
| 990 | + self, |
| 991 | + reqs: List[GraphQLRequest], |
| 992 | + serialize_variables: Optional[bool] = None, |
| 993 | + parse_result: Optional[bool] = None, |
| 994 | + get_execution_result: bool = False, |
| 995 | + **kwargs, |
| 996 | + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: |
| 997 | + """Execute multiple GraphQL requests in a batch, using |
| 998 | + the sync transport. This method sends the requests to the server all at once. |
| 999 | +
|
| 1000 | + Raises a TransportQueryError if an error has been returned in any |
| 1001 | + ExecutionResult. |
| 1002 | +
|
| 1003 | + :param reqs: List of requests that will be executed. |
| 1004 | + :param serialize_variables: whether the variable values should be |
| 1005 | + serialized. Used for custom scalars and/or enums. |
| 1006 | + By default use the serialize_variables argument of the client. |
| 1007 | + :param parse_result: Whether gql will unserialize the result. |
| 1008 | + By default use the parse_results argument of the client. |
| 1009 | + :param get_execution_result: return the full ExecutionResult instance instead of |
| 1010 | + only the "data" field. Necessary if you want to get the "extensions" field. |
| 1011 | +
|
| 1012 | + The extra arguments are passed to the transport execute method.""" |
| 1013 | + |
| 1014 | + # Validate and execute on the transport |
| 1015 | + results = self._execute_batch( |
| 1016 | + reqs, |
| 1017 | + serialize_variables=serialize_variables, |
| 1018 | + parse_result=parse_result, |
| 1019 | + **kwargs, |
| 1020 | + ) |
| 1021 | + |
| 1022 | + for result in results: |
| 1023 | + # Raise an error if an error is returned in the ExecutionResult object |
| 1024 | + if result.errors: |
| 1025 | + raise TransportQueryError( |
| 1026 | + str_first_element(result.errors), |
| 1027 | + errors=result.errors, |
| 1028 | + data=result.data, |
| 1029 | + extensions=result.extensions, |
| 1030 | + ) |
| 1031 | + |
| 1032 | + assert ( |
| 1033 | + result.data is not None |
| 1034 | + ), "Transport returned an ExecutionResult without data or errors" |
| 1035 | + |
| 1036 | + if get_execution_result: |
| 1037 | + return results |
| 1038 | + |
| 1039 | + return cast(List[Dict[str, Any]], [result.data for result in results]) |
| 1040 | + |
883 | 1041 | def fetch_schema(self) -> None:
|
884 | 1042 | """Fetch the GraphQL schema explicitly using introspection.
|
885 | 1043 |
|
@@ -966,7 +1124,6 @@ async def _subscribe(
|
966 | 1124 |
|
967 | 1125 | try:
|
968 | 1126 | async for result in inner_generator:
|
969 |
| - |
970 | 1127 | if self.client.schema:
|
971 | 1128 | if parse_result or (
|
972 | 1129 | parse_result is None and self.client.parse_results
|
@@ -1070,7 +1227,6 @@ async def subscribe(
|
1070 | 1227 | try:
|
1071 | 1228 | # Validate and subscribe on the transport
|
1072 | 1229 | async for result in inner_generator:
|
1073 |
| - |
1074 | 1230 | # Raise an error if an error is returned in the ExecutionResult object
|
1075 | 1231 | if result.errors:
|
1076 | 1232 | raise TransportQueryError(
|
@@ -1343,7 +1499,6 @@ async def _connection_loop(self):
|
1343 | 1499 | """
|
1344 | 1500 |
|
1345 | 1501 | while True:
|
1346 |
| - |
1347 | 1502 | # Connect to the transport with the retry decorator
|
1348 | 1503 | # By default it should keep retrying until it connect
|
1349 | 1504 | await self._connect_with_retries()
|
|
0 commit comments