Skip to content

Commit 2865994

Browse files
committed
Some clean-up in middleware implementation
1 parent a311b86 commit 2865994

File tree

7 files changed

+156
-182
lines changed

7 files changed

+156
-182
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ a query language for APIs created by Facebook.
1313

1414
The current version 1.0.1 of GraphQL-core-next is up-to-date with GraphQL.js
1515
version 14.0.2. All parts of the API are covered by an extensive test suite of
16-
currently 1606 unit tests.
16+
currently 1609 unit tests.
1717

1818

1919
## Documentation

graphql/execution/__init__.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,13 @@
55
"""
66

77
from .execute import (
8-
execute,
9-
default_field_resolver,
10-
response_path_as_list,
11-
ExecutionContext,
12-
ExecutionResult,
13-
)
8+
execute, default_field_resolver, response_path_as_list,
9+
ExecutionContext, ExecutionResult, Middleware)
1410
from .middleware import MiddlewareManager
1511
from .values import get_directive_values
1612

1713
__all__ = [
18-
"execute",
19-
"default_field_resolver",
20-
"response_path_as_list",
21-
"ExecutionContext",
22-
"ExecutionResult",
23-
"MiddlewareManager",
24-
"get_directive_values",
25-
]
14+
'execute', 'default_field_resolver', 'response_path_as_list',
15+
'ExecutionContext', 'ExecutionResult',
16+
'Middleware', 'MiddlewareManager',
17+
'get_directive_values']

graphql/execution/execute.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
DocumentNode, FieldNode, FragmentDefinitionNode,
99
FragmentSpreadNode, InlineFragmentNode, OperationDefinitionNode,
1010
OperationType, SelectionSetNode)
11+
from .middleware import MiddlewareManager
1112
from ..pyutils import is_invalid, is_nullish, MaybeAwaitable
1213
from ..utilities import get_operation_root_type, type_from_ast
1314
from ..type import (
@@ -20,13 +21,11 @@
2021
is_non_null_type, is_object_type)
2122
from .values import (
2223
get_argument_values, get_directive_values, get_variable_values)
23-
from .middleware import MiddlewareManager
24-
2524

2625
__all__ = [
2726
'add_path', 'assert_valid_execution_arguments', 'default_field_resolver',
2827
'execute', 'get_field_def', 'response_path_as_list',
29-
'ExecutionResult', 'ExecutionContext']
28+
'ExecutionResult', 'ExecutionContext', 'Middleware']
3029

3130

3231
# Terminology
@@ -61,13 +60,16 @@ class ExecutionResult(NamedTuple):
6160

6261
ExecutionResult.__new__.__defaults__ = (None, None) # type: ignore
6362

63+
Middleware = Optional[Union[Iterable[Any], MiddlewareManager]]
64+
6465

6566
def execute(
6667
schema: GraphQLSchema, document: DocumentNode,
6768
root_value: Any=None, context_value: Any=None,
6869
variable_values: Dict[str, Any]=None,
69-
operation_name: str=None, field_resolver: GraphQLFieldResolver=None,
70-
middleware: Optional[Union[Iterable[Any], MiddlewareManager]]=None
70+
operation_name: str=None,
71+
field_resolver: GraphQLFieldResolver=None,
72+
middleware: Middleware=None
7173
) -> MaybeAwaitable[ExecutionResult]:
7274
"""Execute a GraphQL operation.
7375
@@ -151,7 +153,7 @@ def build(
151153
raw_variable_values: Dict[str, Any]=None,
152154
operation_name: str=None,
153155
field_resolver: GraphQLFieldResolver=None,
154-
middleware: Optional[Union[Iterable[Any], MiddlewareManager]]=None
156+
middleware: Middleware=None
155157
) -> Union[List[GraphQLError], 'ExecutionContext']:
156158
"""Build an execution context
157159
@@ -165,16 +167,16 @@ def build(
165167
has_multiple_assumed_operations = False
166168
fragments: Dict[str, FragmentDefinitionNode] = {}
167169
middleware_manager: Optional[MiddlewareManager] = None
168-
if middleware:
170+
if middleware is not None:
169171
if isinstance(middleware, Iterable):
170172
middleware_manager = MiddlewareManager(*middleware)
171173
elif isinstance(middleware, MiddlewareManager):
172174
middleware_manager = middleware
173175
else:
174176
raise TypeError(
175-
f"middlewares have to be an instance"
176-
"of MiddlewareManager. Received \"{middleware}\""
177-
)
177+
"Middleware must be passed as a sequence of functions"
178+
" or objects, or as a single MiddlewareManager object."
179+
f" Got {middleware!r} instead.")
178180

179181
for definition in document.definitions:
180182
if isinstance(definition, OperationDefinitionNode):

graphql/execution/middleware.py

+36-38
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,74 @@
1-
from typing import Callable, Iterator, Dict, Tuple, Any, Iterable, Optional, cast
2-
3-
from inspect import isfunction
41
from functools import partial
2+
from inspect import isfunction
53
from itertools import chain
64

5+
from typing import (
6+
Callable, Iterator, Dict, Tuple, Any, Iterable, Optional, cast)
77

8-
from ..type import GraphQLFieldResolver
9-
8+
__all__ = ['MiddlewareManager']
109

11-
__all__ = ["MiddlewareManager", "middlewares"]
12-
13-
# If the provided middleware is a class, this is the attribute we will look at
14-
MIDDLEWARE_RESOLVER_FUNCTION = "resolve"
10+
GraphQLFieldResolver = Callable[..., Any]
1511

1612

1713
class MiddlewareManager:
18-
"""MiddlewareManager helps to chain resolver functions with the provided
19-
middleware functions and classes
14+
"""Manager for the middleware chain.
15+
16+
This class helps to wrap resolver functions with the provided middleware
17+
functions and/or objects. The functions take the next middleware function
18+
as first argument. If middleware is provided as an object, it must provide
19+
a method 'resolve' that is used as the middleware function.
2020
"""
2121

22-
__slots__ = ("middlewares", "_middleware_resolvers", "_cached_resolvers")
22+
__slots__ = 'middlewares', '_middleware_resolvers', '_cached_resolvers'
2323

2424
_cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver]
25-
_middleware_resolvers: Optional[Tuple[Callable, ...]]
25+
_middleware_resolvers: Optional[Iterator[Callable]]
2626

2727
def __init__(self, *middlewares: Any) -> None:
2828
self.middlewares = middlewares
29-
if middlewares:
30-
self._middleware_resolvers = tuple(get_middleware_resolvers(middlewares))
31-
else:
32-
self.__middleware_resolvers = None
29+
self._middleware_resolvers = get_middleware_resolvers(
30+
middlewares) if middlewares else None
3331
self._cached_resolvers = {}
3432

3533
def get_field_resolver(
36-
self, field_resolver: GraphQLFieldResolver
37-
) -> GraphQLFieldResolver:
38-
"""Wraps the provided resolver returning a function that
39-
executes chains the middleware functions with the resolver function"""
34+
self, field_resolver: GraphQLFieldResolver
35+
) -> GraphQLFieldResolver:
36+
"""Wrap the provided resolver with the middleware.
37+
38+
Returns a function that chains the middleware functions with the
39+
provided resolver function
40+
"""
4041
if self._middleware_resolvers is None:
4142
return field_resolver
4243
if field_resolver not in self._cached_resolvers:
4344
self._cached_resolvers[field_resolver] = middleware_chain(
44-
field_resolver, self._middleware_resolvers
45-
)
46-
45+
field_resolver, self._middleware_resolvers)
4746
return self._cached_resolvers[field_resolver]
4847

4948

50-
middlewares = MiddlewareManager
51-
52-
53-
def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
54-
"""Returns the functions related to the middleware classes or functions"""
49+
def get_middleware_resolvers(
50+
middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
51+
"""Get a list of resolver functions from a list of classes or functions."""
5552
for middleware in middlewares:
56-
# If the middleware is a function instead of a class
5753
if isfunction(middleware):
5854
yield middleware
59-
resolver_func = getattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION, None)
60-
if resolver_func is not None:
61-
yield resolver_func
55+
else: # middleware provided as object with 'resolve' method
56+
resolver_func = getattr(middleware, 'resolve', None)
57+
if resolver_func is not None:
58+
yield resolver_func
6259

6360

6461
def middleware_chain(
65-
func: GraphQLFieldResolver, middlewares: Iterable[Callable]
66-
) -> GraphQLFieldResolver:
67-
"""Reduces the current function with the provided middlewares,
68-
returning a new resolver function"""
62+
func: GraphQLFieldResolver, middlewares: Iterable[Callable]
63+
) -> GraphQLFieldResolver:
64+
"""Chain the given function with the provided middlewares.
65+
66+
Returns a new resolver function that is the chain of both.
67+
"""
6968
if not middlewares:
7069
return func
7170
middlewares = chain((func,), middlewares)
7271
last_func: Optional[GraphQLFieldResolver] = None
7372
for middleware in middlewares:
7473
last_func = partial(middleware, last_func) if last_func else middleware
75-
7674
return cast(GraphQLFieldResolver, last_func)

graphql/graphql.py

+34-39
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,26 @@
11
from asyncio import ensure_future
22
from inspect import isawaitable
3-
from typing import Any, Awaitable, Callable, Dict, Union, Optional, Iterable, cast
3+
from typing import Any, Awaitable, Callable, Dict, Union, cast
44

55
from .error import GraphQLError
6-
from .execution import execute
6+
from .execution import execute, ExecutionResult, Middleware
77
from .language import parse, Source
88
from .pyutils import MaybeAwaitable
99
from .type import GraphQLSchema, validate_schema
10-
from .execution import ExecutionResult, MiddlewareManager
1110

12-
__all__ = ["graphql", "graphql_sync"]
11+
__all__ = ['graphql', 'graphql_sync']
1312

1413

1514
async def graphql(
16-
schema: GraphQLSchema,
17-
source: Union[str, Source],
18-
root_value: Any = None,
19-
context_value: Any = None,
20-
variable_values: Dict[str, Any] = None,
21-
operation_name: str = None,
22-
field_resolver: Callable = None,
23-
middleware: Optional[Union[Iterable[Any], MiddlewareManager]] = None,
24-
) -> ExecutionResult:
15+
schema: GraphQLSchema,
16+
source: Union[str, Source],
17+
root_value: Any=None,
18+
context_value: Any=None,
19+
variable_values: Dict[str, Any]=None,
20+
operation_name: str=None,
21+
field_resolver: Callable=None,
22+
middleware: Middleware=None
23+
) -> ExecutionResult:
2524
"""Execute a GraphQL operation asynchronously.
2625
2726
This is the primary entry point function for fulfilling GraphQL operations
@@ -70,8 +69,7 @@ async def graphql(
7069
variable_values,
7170
operation_name,
7271
field_resolver,
73-
middleware,
74-
)
72+
middleware)
7573

7674
if isawaitable(result):
7775
return await cast(Awaitable[ExecutionResult], result)
@@ -80,15 +78,15 @@ async def graphql(
8078

8179

8280
def graphql_sync(
83-
schema: GraphQLSchema,
84-
source: Union[str, Source],
85-
root_value: Any = None,
86-
context_value: Any = None,
87-
variable_values: Dict[str, Any] = None,
88-
operation_name: str = None,
89-
field_resolver: Callable = None,
90-
middleware: Optional[Union[Iterable[Any], MiddlewareManager]] = None,
91-
) -> ExecutionResult:
81+
schema: GraphQLSchema,
82+
source: Union[str, Source],
83+
root_value: Any=None,
84+
context_value: Any=None,
85+
variable_values: Dict[str, Any]=None,
86+
operation_name: str=None,
87+
field_resolver: Callable=None,
88+
middleware: Middleware=None
89+
) -> ExecutionResult:
9290
"""Execute a GraphQL operation synchronously.
9391
9492
The graphql_sync function also fulfills GraphQL operations by parsing,
@@ -104,27 +102,26 @@ def graphql_sync(
104102
variable_values,
105103
operation_name,
106104
field_resolver,
107-
middleware,
108-
)
105+
middleware)
109106

110107
# Assert that the execution was synchronous.
111108
if isawaitable(result):
112109
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
113-
raise RuntimeError("GraphQL execution failed to complete synchronously.")
110+
raise RuntimeError(
111+
"GraphQL execution failed to complete synchronously.")
114112

115113
return cast(ExecutionResult, result)
116114

117115

118116
def graphql_impl(
119-
schema,
120-
source,
121-
root_value,
122-
context_value,
123-
variable_values,
124-
operation_name,
125-
field_resolver,
126-
middleware,
127-
) -> MaybeAwaitable[ExecutionResult]:
117+
schema,
118+
source,
119+
root_value,
120+
context_value,
121+
variable_values,
122+
operation_name,
123+
field_resolver,
124+
middleware) -> MaybeAwaitable[ExecutionResult]:
128125
"""Execute a query, return asynchronously only if necessary."""
129126
# Validate Schema
130127
schema_validation_errors = validate_schema(schema)
@@ -142,7 +139,6 @@ def graphql_impl(
142139

143140
# Validate
144141
from .validation import validate
145-
146142
validation_errors = validate(schema, document)
147143
if validation_errors:
148144
return ExecutionResult(data=None, errors=validation_errors)
@@ -156,5 +152,4 @@ def graphql_impl(
156152
variable_values,
157153
operation_name,
158154
field_resolver,
159-
middleware,
160-
)
155+
middleware)

0 commit comments

Comments
 (0)