diff --git a/graphql_server/aiohttp/graphqlview.py b/graphql_server/aiohttp/graphqlview.py index fa4d998..b5891d1 100644 --- a/graphql_server/aiohttp/graphqlview.py +++ b/graphql_server/aiohttp/graphqlview.py @@ -1,10 +1,12 @@ +import asyncio import copy from collections.abc import MutableMapping from functools import partial from typing import List from aiohttp import web -from graphql import ExecutionResult, GraphQLError, specified_rules +from graphql import GraphQLError, specified_rules +from graphql.pyutils import is_awaitable from graphql.type.schema import GraphQLSchema from graphql_server import ( @@ -22,6 +24,7 @@ GraphiQLOptions, render_graphiql_async, ) +from graphql_server.utils import wrap_in_async class GraphQLView: @@ -166,10 +169,14 @@ async def __call__(self, request): ) exec_res = ( - [ - ex if ex is None or isinstance(ex, ExecutionResult) else await ex - for ex in execution_results - ] + await asyncio.gather( + *( + ex + if ex is not None and is_awaitable(ex) + else wrap_in_async(lambda: ex)() + for ex in execution_results + ) + ) if self.enable_async else execution_results ) diff --git a/graphql_server/quart/graphqlview.py b/graphql_server/quart/graphqlview.py index 2ac624b..d7b209f 100644 --- a/graphql_server/quart/graphqlview.py +++ b/graphql_server/quart/graphqlview.py @@ -1,10 +1,12 @@ +import asyncio import copy from collections.abc import MutableMapping from functools import partial from typing import List -from graphql import ExecutionResult, specified_rules +from graphql import specified_rules from graphql.error import GraphQLError +from graphql.pyutils import is_awaitable from graphql.type.schema import GraphQLSchema from quart import Response, render_template_string, request from quart.views import View @@ -24,6 +26,7 @@ GraphiQLOptions, render_graphiql_sync, ) +from graphql_server.utils import wrap_in_async class GraphQLView(View): @@ -113,10 +116,14 @@ async def dispatch_request(self): execution_context_class=self.get_execution_context_class(), ) exec_res = ( - [ - ex if ex is None or isinstance(ex, ExecutionResult) else await ex - for ex in execution_results - ] + await asyncio.gather( + *( + ex + if ex is not None and is_awaitable(ex) + else wrap_in_async(lambda: ex)() + for ex in execution_results + ) + ) if self.enable_async else execution_results ) diff --git a/graphql_server/sanic/graphqlview.py b/graphql_server/sanic/graphqlview.py index 7bea500..814d489 100644 --- a/graphql_server/sanic/graphqlview.py +++ b/graphql_server/sanic/graphqlview.py @@ -1,10 +1,12 @@ +import asyncio import copy from cgi import parse_header from collections.abc import MutableMapping from functools import partial from typing import List -from graphql import ExecutionResult, GraphQLError, specified_rules +from graphql import GraphQLError, specified_rules +from graphql.pyutils import is_awaitable from graphql.type.schema import GraphQLSchema from sanic.response import HTTPResponse, html from sanic.views import HTTPMethodView @@ -24,6 +26,7 @@ GraphiQLOptions, render_graphiql_async, ) +from graphql_server.utils import wrap_in_async class GraphQLView(HTTPMethodView): @@ -119,12 +122,14 @@ async def __handle_request(self, request, *args, **kwargs): execution_context_class=self.get_execution_context_class(), ) exec_res = ( - [ - ex - if ex is None or isinstance(ex, ExecutionResult) - else await ex - for ex in execution_results - ] + await asyncio.gather( + *( + ex + if ex is not None and is_awaitable(ex) + else wrap_in_async(lambda: ex)() + for ex in execution_results + ) + ) if self.enable_async else execution_results ) diff --git a/graphql_server/utils.py b/graphql_server/utils.py new file mode 100644 index 0000000..c52a24b --- /dev/null +++ b/graphql_server/utils.py @@ -0,0 +1,25 @@ +import sys +from typing import Awaitable, Callable, TypeVar + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + + +__all__ = ["wrap_in_async"] + +P = ParamSpec("P") +R = TypeVar("R") + + +def wrap_in_async(f: Callable[P, R]) -> Callable[P, Awaitable[R]]: + """Convert a sync callable (normal def or lambda) to a coroutine (async def). + + This is similar to asyncio.coroutine which was deprecated in Python 3.8. + """ + + async def f_async(*args: P.args, **kwargs: P.kwargs) -> R: + return f(*args, **kwargs) + + return f_async diff --git a/tests/quart/app.py b/tests/quart/app.py index 2313f99..adfce41 100644 --- a/tests/quart/app.py +++ b/tests/quart/app.py @@ -4,11 +4,11 @@ from tests.quart.schema import Schema -def create_app(path="/graphql", **kwargs): +def create_app(path="/graphql", schema=Schema, **kwargs): server = Quart(__name__) server.debug = True server.add_url_rule( - path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs) + path, view_func=GraphQLView.as_view("graphql", schema=schema, **kwargs) ) return server diff --git a/tests/quart/schema.py b/tests/quart/schema.py index eb51e26..7aca9d8 100644 --- a/tests/quart/schema.py +++ b/tests/quart/schema.py @@ -1,3 +1,5 @@ +import asyncio + from graphql.type.definition import ( GraphQLArgument, GraphQLField, @@ -12,6 +14,7 @@ def resolve_raises(*_): raise Exception("Throws!") +# Sync schema QueryRootType = GraphQLObjectType( name="QueryRoot", fields={ @@ -36,7 +39,7 @@ def resolve_raises(*_): "test": GraphQLField( type_=GraphQLString, args={"who": GraphQLArgument(GraphQLString)}, - resolve=lambda obj, info, who="World": "Hello %s" % who, + resolve=lambda obj, info, who="World": f"Hello {who}", ), }, ) @@ -49,3 +52,48 @@ def resolve_raises(*_): ) Schema = GraphQLSchema(QueryRootType, MutationRootType) + + +# Schema with async methods +async def resolver_field_async_1(_obj, info): + await asyncio.sleep(0.001) + return "hey" + + +async def resolver_field_async_2(_obj, info): + await asyncio.sleep(0.003) + return "hey2" + + +def resolver_field_sync(_obj, info): + return "hey3" + + +AsyncQueryType = GraphQLObjectType( + name="AsyncQueryType", + fields={ + "a": GraphQLField(GraphQLString, resolve=resolver_field_async_1), + "b": GraphQLField(GraphQLString, resolve=resolver_field_async_2), + "c": GraphQLField(GraphQLString, resolve=resolver_field_sync), + }, +) + + +def resolver_field_sync_1(_obj, info): + return "synced_one" + + +def resolver_field_sync_2(_obj, info): + return "synced_two" + + +SyncQueryType = GraphQLObjectType( + "SyncQueryType", + { + "a": GraphQLField(GraphQLString, resolve=resolver_field_sync_1), + "b": GraphQLField(GraphQLString, resolve=resolver_field_sync_2), + }, +) + +AsyncSchema = GraphQLSchema(AsyncQueryType) +SyncSchema = GraphQLSchema(SyncQueryType) diff --git a/tests/quart/test_graphqlview.py b/tests/quart/test_graphqlview.py index d0da414..9b2daa2 100644 --- a/tests/quart/test_graphqlview.py +++ b/tests/quart/test_graphqlview.py @@ -9,6 +9,7 @@ from ..utils import RepeatExecutionContext from .app import create_app +from .schema import AsyncSchema @pytest.fixture @@ -736,6 +737,20 @@ async def test_batch_allows_post_with_operation_name( ] +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(schema=AsyncSchema, enable_async=True)]) +async def test_async_schema(app, client): + response = await execute_client( + app, + client, + query="{a,b,c}", + ) + + assert response.status_code == 200 + result = await response.get_data(as_text=True) + assert response_json(result) == {"data": {"a": "hey", "b": "hey2", "c": "hey3"}} + + @pytest.mark.asyncio @pytest.mark.parametrize( "app", [create_app(execution_context_class=RepeatExecutionContext)]