diff --git a/graphql_server/aiohttp/graphqlview.py b/graphql_server/aiohttp/graphqlview.py index 9581e12..9d28f02 100644 --- a/graphql_server/aiohttp/graphqlview.py +++ b/graphql_server/aiohttp/graphqlview.py @@ -1,12 +1,14 @@ import copy from collections.abc import MutableMapping from functools import partial +from typing import List from aiohttp import web from graphql import GraphQLError from graphql.type.schema import GraphQLSchema from graphql_server import ( + GraphQLParams, HttpQueryError, encode_execution_results, format_error_default, @@ -14,8 +16,11 @@ load_json_body, run_http_query, ) - -from .render_graphiql import render_graphiql +from graphql_server.render_graphiql import ( + GraphiQLConfig, + GraphiQLData, + render_graphiql_async, +) class GraphQLView: @@ -26,12 +31,14 @@ class GraphQLView: graphiql = False graphiql_version = None graphiql_template = None + graphiql_html_title = None middleware = None batch = False jinja_env = None max_age = 86400 enable_async = False subscriptions = None + headers = None accepted_methods = ["GET", "POST", "PUT", "DELETE"] @@ -88,16 +95,6 @@ async def parse_body(self, request): return {} - def render_graphiql(self, params, result): - return render_graphiql( - jinja_env=self.jinja_env, - params=params, - result=result, - graphiql_version=self.graphiql_version, - graphiql_template=self.graphiql_template, - subscriptions=self.subscriptions, - ) - # TODO: # use this method to replace flask and sanic # checks as this is equivalent to `should_display_graphiql` and @@ -135,6 +132,7 @@ async def __call__(self, request): if request_method == "options": return self.process_preflight(request) + all_params: List[GraphQLParams] execution_results, all_params = run_http_query( self.schema, request_method, @@ -162,7 +160,24 @@ async def __call__(self, request): ) if is_graphiql: - return await self.render_graphiql(params=all_params[0], result=result) + graphiql_data = GraphiQLData( + result=result, + query=getattr(all_params[0], "query"), + variables=getattr(all_params[0], "variables"), + operation_name=getattr(all_params[0], "operation_name"), + subscription_url=self.subscriptions, + headers=self.headers, + ) + graphiql_config = GraphiQLConfig( + graphiql_version=self.graphiql_version, + graphiql_template=self.graphiql_template, + graphiql_html_title=self.graphiql_html_title, + jinja_env=self.jinja_env, + ) + source = await render_graphiql_async( + data=graphiql_data, config=graphiql_config + ) + return web.Response(text=source, content_type="text/html") return web.Response( text=result, status=status_code, content_type="application/json", diff --git a/graphql_server/aiohttp/render_graphiql.py b/graphql_server/aiohttp/render_graphiql.py deleted file mode 100644 index 9da47d3..0000000 --- a/graphql_server/aiohttp/render_graphiql.py +++ /dev/null @@ -1,208 +0,0 @@ -import json -import re - -from aiohttp import web - -GRAPHIQL_VERSION = "0.17.5" - -TEMPLATE = """ - - - - - - - - - - - - - - - - -""" - - -def escape_js_value(value): - quotation = False - if value.startswith('"') and value.endswith('"'): - quotation = True - value = value[1:-1] - - value = value.replace("\\\\n", "\\\\\\n").replace("\\n", "\\\\n") - if quotation: - value = '"' + value.replace('\\\\"', '"').replace('"', '\\"') + '"' - - return value - - -def process_var(template, name, value, jsonify=False): - pattern = r"{{\s*" + name + r"(\s*|[^}]+)*\s*}}" - if jsonify and value not in ["null", "undefined"]: - value = json.dumps(value) - value = escape_js_value(value) - - return re.sub(pattern, value, template) - - -def simple_renderer(template, **values): - replace = ["graphiql_version", "subscriptions"] - replace_jsonify = ["query", "result", "variables", "operation_name"] - - for rep in replace: - template = process_var(template, rep, values.get(rep, "")) - - for rep in replace_jsonify: - template = process_var(template, rep, values.get(rep, ""), True) - - return template - - -async def render_graphiql( - jinja_env=None, - graphiql_version=None, - graphiql_template=None, - params=None, - result=None, - subscriptions=None, -): - graphiql_version = graphiql_version or GRAPHIQL_VERSION - template = graphiql_template or TEMPLATE - template_vars = { - "graphiql_version": graphiql_version, - "query": params and params.query, - "variables": params and params.variables, - "operation_name": params and params.operation_name, - "result": result, - "subscriptions": subscriptions or "", - } - - if jinja_env: - template = jinja_env.from_string(template) - if jinja_env.is_async: - source = await template.render_async(**template_vars) - else: - source = template.render(**template_vars) - else: - source = simple_renderer(template, **template_vars) - - return web.Response(text=source, content_type="text/html") diff --git a/graphql_server/flask/graphqlview.py b/graphql_server/flask/graphqlview.py index 1a2f9af..9108a41 100644 --- a/graphql_server/flask/graphqlview.py +++ b/graphql_server/flask/graphqlview.py @@ -1,11 +1,13 @@ from functools import partial +from typing import List -from flask import Response, request +from flask import Response, render_template_string, request from flask.views import View from graphql.error import GraphQLError from graphql.type.schema import GraphQLSchema from graphql_server import ( + GraphQLParams, HttpQueryError, encode_execution_results, format_error_default, @@ -13,8 +15,11 @@ load_json_body, run_http_query, ) - -from .render_graphiql import render_graphiql +from graphql_server.render_graphiql import ( + GraphiQLConfig, + GraphiQLData, + render_graphiql_sync, +) class GraphQLView(View): @@ -27,6 +32,8 @@ class GraphQLView(View): graphiql_html_title = None middleware = None batch = False + subscriptions = None + headers = None methods = ["GET", "POST", "PUT", "DELETE"] @@ -50,15 +57,6 @@ def get_context_value(self): def get_middleware(self): return self.middleware - def render_graphiql(self, params, result): - return render_graphiql( - params=params, - result=result, - graphiql_version=self.graphiql_version, - graphiql_template=self.graphiql_template, - graphiql_html_title=self.graphiql_html_title, - ) - format_error = staticmethod(format_error_default) encode = staticmethod(json_encode) @@ -72,6 +70,7 @@ def dispatch_request(self): pretty = self.pretty or show_graphiql or request.args.get("pretty") + all_params: List[GraphQLParams] execution_results, all_params = run_http_query( self.schema, request_method, @@ -88,11 +87,28 @@ def dispatch_request(self): execution_results, is_batch=isinstance(data, list), format_error=self.format_error, - encode=partial(self.encode, pretty=pretty), + encode=partial(self.encode, pretty=pretty), # noqa ) if show_graphiql: - return self.render_graphiql(params=all_params[0], result=result) + graphiql_data = GraphiQLData( + result=result, + query=getattr(all_params[0], "query"), + variables=getattr(all_params[0], "variables"), + operation_name=getattr(all_params[0], "operation_name"), + subscription_url=self.subscriptions, + headers=self.headers, + ) + graphiql_config = GraphiQLConfig( + graphiql_version=self.graphiql_version, + graphiql_template=self.graphiql_template, + graphiql_html_title=self.graphiql_html_title, + jinja_env=None, + ) + source = render_graphiql_sync( + data=graphiql_data, config=graphiql_config + ) + return render_template_string(source) return Response(result, status=status_code, content_type="application/json") diff --git a/graphql_server/flask/render_graphiql.py b/graphql_server/flask/render_graphiql.py deleted file mode 100644 index d395d44..0000000 --- a/graphql_server/flask/render_graphiql.py +++ /dev/null @@ -1,148 +0,0 @@ -from flask import render_template_string - -GRAPHIQL_VERSION = "0.11.11" - -TEMPLATE = """ - - - - {{graphiql_html_title|default("GraphiQL", true)}} - - - - - - - - - - - -""" - - -def render_graphiql( - params, - result, - graphiql_version=None, - graphiql_template=None, - graphiql_html_title=None, -): - graphiql_version = graphiql_version or GRAPHIQL_VERSION - template = graphiql_template or TEMPLATE - - return render_template_string( - template, - graphiql_version=graphiql_version, - graphiql_html_title=graphiql_html_title, - result=result, - params=params, - ) diff --git a/graphql_server/render_graphiql.py b/graphql_server/render_graphiql.py new file mode 100644 index 0000000..8ae4107 --- /dev/null +++ b/graphql_server/render_graphiql.py @@ -0,0 +1,330 @@ +"""Based on (express-graphql)[https://github.com/graphql/express-graphql/blob/master/src/renderGraphiQL.js] and +(subscriptions-transport-ws)[https://github.com/apollographql/subscriptions-transport-ws]""" +import json +import re +from typing import Any, Dict, Optional, Tuple + +from jinja2 import Environment +from typing_extensions import TypedDict + +GRAPHIQL_VERSION = "1.0.3" + +GRAPHIQL_TEMPLATE = """ + + + + + {{graphiql_html_title}} + + + + + + + + + + + + + + +
Loading...
+ + +""" + + +class GraphiQLData(TypedDict): + """GraphiQL ReactDom Data + + Has the following attributes: + + subscription_url + The GraphiQL socket endpoint for using subscriptions in graphql-ws. + headers + An optional GraphQL string to use as the initial displayed request headers, + if None is provided, the stored headers will be used. + """ + + query: Optional[str] + variables: Optional[str] + operation_name: Optional[str] + result: Optional[str] + subscription_url: Optional[str] + headers: Optional[str] + + +class GraphiQLConfig(TypedDict): + """GraphiQL Extra Config + + Has the following attributes: + + graphiql_version + The version of the provided GraphiQL package. + graphiql_template + Inject a Jinja template string to customize GraphiQL. + graphiql_html_title + Replace the default html title on the GraphiQL. + jinja_env + Sets jinja environment to be used to process GraphiQL template. + If Jinja’s async mode is enabled (by enable_async=True), + uses Template.render_async instead of Template.render. + If environment is not set, fallbacks to simple regex-based renderer. + """ + + graphiql_version: Optional[str] + graphiql_template: Optional[str] + graphiql_html_title: Optional[str] + jinja_env: Optional[Environment] + + +class GraphiQLOptions(TypedDict): + """GraphiQL options to display on the UI. + + Has the following attributes: + + default_query + An optional GraphQL string to use when no query is provided and no stored + query exists from a previous session. If undefined is provided, GraphiQL + will use its own default query. + header_editor_enabled + An optional boolean which enables the header editor when true. + Defaults to false. + should_persist_headers + An optional boolean which enables to persist headers to storage when true. + Defaults to false. + """ + + default_query: Optional[str] + header_editor_enabled: Optional[bool] + should_persist_headers: Optional[bool] + + +def escape_js_value(value: Any) -> Any: + quotation = False + if value.startswith('"') and value.endswith('"'): + quotation = True + value = value[1 : len(value) - 1] + + value = value.replace("\\\\n", "\\\\\\n").replace("\\n", "\\\\n") + if quotation: + value = '"' + value.replace('\\\\"', '"').replace('"', '\\"') + '"' + + return value + + +def process_var(template: str, name: str, value: Any, jsonify=False) -> str: + pattern = r"{{\s*" + name + r"(\s*|[^}]+)*\s*}}" + if jsonify and value not in ["null", "undefined"]: + value = json.dumps(value) + value = escape_js_value(value) + + return re.sub(pattern, value, template) + + +def simple_renderer(template: str, **values: Dict[str, Any]) -> str: + replace = [ + "graphiql_version", + "graphiql_html_title", + "subscription_url", + "header_editor_enabled", + "should_persist_headers", + ] + replace_jsonify = [ + "query", + "result", + "variables", + "operation_name", + "default_query", + "headers", + ] + + for r in replace: + template = process_var(template, r, values.get(r, "")) + + for r in replace_jsonify: + template = process_var(template, r, values.get(r, ""), True) + + return template + + +def _render_graphiql( + data: GraphiQLData, + config: GraphiQLConfig, + options: Optional[GraphiQLOptions] = None, +) -> Tuple[str, Dict[str, Any]]: + """When render_graphiql receives a request which does not Accept JSON, but does + Accept HTML, it may present GraphiQL, the in-browser GraphQL explorer IDE. + When shown, it will be pre-populated with the result of having executed + the requested query. + """ + graphiql_version = config.get("graphiql_version") or GRAPHIQL_VERSION + graphiql_template = config.get("graphiql_template") or GRAPHIQL_TEMPLATE + graphiql_html_title = config.get("graphiql_html_title") or "GraphiQL" + + template_vars: Dict[str, Any] = { + "graphiql_version": graphiql_version, + "graphiql_html_title": graphiql_html_title, + "query": data.get("query"), + "variables": data.get("variables"), + "operation_name": data.get("operation_name"), + "result": data.get("result"), + "subscription_url": data.get("subscription_url") or "", + "headers": data.get("headers") or "", + "default_query": options and options.get("default_query") or "", + "header_editor_enabled": options + and options.get("header_editor_enabled") + or "true", + "should_persist_headers": options + and options.get("should_persist_headers") + or "false", + } + + return graphiql_template, template_vars + + +async def render_graphiql_async( + data: GraphiQLData, + config: GraphiQLConfig, + options: Optional[GraphiQLOptions] = None, +) -> str: + graphiql_template, template_vars = _render_graphiql(data, config, options) + jinja_env: Optional[Environment] = config.get("jinja_env") + + if jinja_env: + # This method returns a Template. See https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Template + template = jinja_env.from_string(graphiql_template) + if jinja_env.is_async: # type: ignore + source = await template.render_async(**template_vars) + else: + source = template.render(**template_vars) + else: + source = simple_renderer(graphiql_template, **template_vars) + return source + + +def render_graphiql_sync( + data: GraphiQLData, + config: GraphiQLConfig, + options: Optional[GraphiQLOptions] = None, +) -> str: + graphiql_template, template_vars = _render_graphiql(data, config, options) + + source = simple_renderer(graphiql_template, **template_vars) + return source diff --git a/graphql_server/sanic/graphqlview.py b/graphql_server/sanic/graphqlview.py index fd22af2..8e2c7b8 100644 --- a/graphql_server/sanic/graphqlview.py +++ b/graphql_server/sanic/graphqlview.py @@ -2,13 +2,15 @@ from cgi import parse_header from collections.abc import MutableMapping from functools import partial +from typing import List from graphql import GraphQLError from graphql.type.schema import GraphQLSchema -from sanic.response import HTTPResponse +from sanic.response import HTTPResponse, html from sanic.views import HTTPMethodView from graphql_server import ( + GraphQLParams, HttpQueryError, encode_execution_results, format_error_default, @@ -16,8 +18,11 @@ load_json_body, run_http_query, ) - -from .render_graphiql import render_graphiql +from graphql_server.render_graphiql import ( + GraphiQLConfig, + GraphiQLData, + render_graphiql_async, +) class GraphQLView(HTTPMethodView): @@ -28,11 +33,14 @@ class GraphQLView(HTTPMethodView): graphiql = False graphiql_version = None graphiql_template = None + graphiql_html_title = None middleware = None batch = False jinja_env = None max_age = 86400 enable_async = False + subscriptions = None + headers = None methods = ["GET", "POST", "PUT", "DELETE"] @@ -62,15 +70,6 @@ def get_context(self, request): def get_middleware(self): return self.middleware - async def render_graphiql(self, params, result): - return await render_graphiql( - jinja_env=self.jinja_env, - params=params, - result=result, - graphiql_version=self.graphiql_version, - graphiql_template=self.graphiql_template, - ) - format_error = staticmethod(format_error_default) encode = staticmethod(json_encode) @@ -87,6 +86,7 @@ async def dispatch_request(self, request, *args, **kwargs): pretty = self.pretty or show_graphiql or request.args.get("pretty") if request_method != "options": + all_params: List[GraphQLParams] execution_results, all_params = run_http_query( self.schema, request_method, @@ -113,9 +113,24 @@ async def dispatch_request(self, request, *args, **kwargs): ) if show_graphiql: - return await self.render_graphiql( - params=all_params[0], result=result + graphiql_data = GraphiQLData( + result=result, + query=getattr(all_params[0], "query"), + variables=getattr(all_params[0], "variables"), + operation_name=getattr(all_params[0], "operation_name"), + subscription_url=self.subscriptions, + headers=self.headers, + ) + graphiql_config = GraphiQLConfig( + graphiql_version=self.graphiql_version, + graphiql_template=self.graphiql_template, + graphiql_html_title=self.graphiql_html_title, + jinja_env=self.jinja_env, + ) + source = await render_graphiql_async( + data=graphiql_data, config=graphiql_config ) + return html(source) return HTTPResponse( result, status=status_code, content_type="application/json" diff --git a/graphql_server/sanic/render_graphiql.py b/graphql_server/sanic/render_graphiql.py deleted file mode 100644 index ca21ee3..0000000 --- a/graphql_server/sanic/render_graphiql.py +++ /dev/null @@ -1,185 +0,0 @@ -import json -import re - -from sanic.response import html - -GRAPHIQL_VERSION = "0.7.1" - -TEMPLATE = """ - - - - - - - - - - - - - - -""" - - -def escape_js_value(value): - quotation = False - if value.startswith('"') and value.endswith('"'): - quotation = True - value = value[1 : len(value) - 1] - - value = value.replace("\\\\n", "\\\\\\n").replace("\\n", "\\\\n") - if quotation: - value = '"' + value.replace('\\\\"', '"').replace('"', '\\"') + '"' - - return value - - -def process_var(template, name, value, jsonify=False): - pattern = r"{{\s*" + name + r"(\s*|[^}]+)*\s*}}" - if jsonify and value not in ["null", "undefined"]: - value = json.dumps(value) - value = escape_js_value(value) - - return re.sub(pattern, value, template) - - -def simple_renderer(template, **values): - replace = ["graphiql_version"] - replace_jsonify = ["query", "result", "variables", "operation_name"] - - for r in replace: - template = process_var(template, r, values.get(r, "")) - - for r in replace_jsonify: - template = process_var(template, r, values.get(r, ""), True) - - return template - - -async def render_graphiql( - jinja_env=None, - graphiql_version=None, - graphiql_template=None, - params=None, - result=None, -): - graphiql_version = graphiql_version or GRAPHIQL_VERSION - template = graphiql_template or TEMPLATE - template_vars = { - "graphiql_version": graphiql_version, - "query": params and params.query, - "variables": params and params.variables, - "operation_name": params and params.operation_name, - "result": result, - } - - if jinja_env: - template = jinja_env.from_string(template) - if jinja_env.is_async: - source = await template.render_async(**template_vars) - else: - source = template.render(**template_vars) - else: - source = simple_renderer(template, **template_vars) - - return html(source) diff --git a/graphql_server/webob/graphqlview.py b/graphql_server/webob/graphqlview.py index a7cec7a..6a32c5b 100644 --- a/graphql_server/webob/graphqlview.py +++ b/graphql_server/webob/graphqlview.py @@ -1,12 +1,14 @@ import copy from collections.abc import MutableMapping from functools import partial +from typing import List from graphql.error import GraphQLError from graphql.type.schema import GraphQLSchema from webob import Response from graphql_server import ( + GraphQLParams, HttpQueryError, encode_execution_results, format_error_default, @@ -14,8 +16,11 @@ load_json_body, run_http_query, ) - -from .render_graphiql import render_graphiql +from graphql_server.render_graphiql import ( + GraphiQLConfig, + GraphiQLData, + render_graphiql_sync, +) class GraphQLView: @@ -27,9 +32,12 @@ class GraphQLView: graphiql = False graphiql_version = None graphiql_template = None + graphiql_html_title = None middleware = None batch = False enable_async = False + subscriptions = None + headers = None charset = "UTF-8" def __init__(self, **kwargs): @@ -73,6 +81,7 @@ def dispatch_request(self, request): pretty = self.pretty or show_graphiql or request.params.get("pretty") + all_params: List[GraphQLParams] execution_results, all_params = run_http_query( self.schema, request_method, @@ -94,8 +103,22 @@ def dispatch_request(self, request): ) if show_graphiql: + graphiql_data = GraphiQLData( + result=result, + query=getattr(all_params[0], "query"), + variables=getattr(all_params[0], "variables"), + operation_name=getattr(all_params[0], "operation_name"), + subscription_url=self.subscriptions, + headers=self.headers, + ) + graphiql_config = GraphiQLConfig( + graphiql_version=self.graphiql_version, + graphiql_template=self.graphiql_template, + graphiql_html_title=self.graphiql_html_title, + jinja_env=None, + ) return Response( - render_graphiql(params=all_params[0], result=result), + render_graphiql_sync(data=graphiql_data, config=graphiql_config), charset=self.charset, content_type="text/html", ) diff --git a/graphql_server/webob/render_graphiql.py b/graphql_server/webob/render_graphiql.py deleted file mode 100644 index 5e9c735..0000000 --- a/graphql_server/webob/render_graphiql.py +++ /dev/null @@ -1,172 +0,0 @@ -import json -import re - -GRAPHIQL_VERSION = "0.17.5" - -TEMPLATE = """ - - - - - - - - - - - - - - -""" - - -def escape_js_value(value): - quotation = False - if value.startswith('"') and value.endswith('"'): - quotation = True - value = value[1 : len(value) - 1] - - value = value.replace("\\\\n", "\\\\\\n").replace("\\n", "\\\\n") - if quotation: - value = '"' + value.replace('\\\\"', '"').replace('"', '\\"') + '"' - - return value - - -def process_var(template, name, value, jsonify=False): - pattern = r"{{\s*" + name + r"(\s*|[^}]+)*\s*}}" - if jsonify and value not in ["null", "undefined"]: - value = json.dumps(value) - value = escape_js_value(value) - - return re.sub(pattern, value, template) - - -def simple_renderer(template, **values): - replace = ["graphiql_version"] - replace_jsonify = ["query", "result", "variables", "operation_name"] - - for r in replace: - template = process_var(template, r, values.get(r, "")) - - for r in replace_jsonify: - template = process_var(template, r, values.get(r, ""), True) - - return template - - -def render_graphiql( - graphiql_version=None, graphiql_template=None, params=None, result=None, -): - graphiql_version = graphiql_version or GRAPHIQL_VERSION - template = graphiql_template or TEMPLATE - - template_vars = { - "graphiql_version": graphiql_version, - "query": params and params.query, - "variables": params and params.variables, - "operation_name": params and params.operation_name, - "result": result, - } - - source = simple_renderer(template, **template_vars) - return source diff --git a/setup.py b/setup.py index 8977038..4c6aa58 100644 --- a/setup.py +++ b/setup.py @@ -2,10 +2,12 @@ install_requires = [ "graphql-core>=3.1.0,<4", + "typing-extensions>=3.7.4,<4" ] tests_requires = [ - "pytest>=5.3,<5.4", + "pytest>=5.4,<5.5", + "pytest-asyncio>=0.11.0", "pytest-cov>=2.8,<3", "aiohttp>=3.5.0,<4", "Jinja2>=2.10.1,<3", diff --git a/tests/aiohttp/test_graphiqlview.py b/tests/aiohttp/test_graphiqlview.py index dfe442a..a4a7a26 100644 --- a/tests/aiohttp/test_graphiqlview.py +++ b/tests/aiohttp/test_graphiqlview.py @@ -70,13 +70,6 @@ async def test_graphiql_jinja_renderer_async(self, app, client, pretty_response) assert response.status == 200 assert pretty_response in await response.text() - async def test_graphiql_jinja_renderer_sync(self, app, client, pretty_response): - response = client.get( - url_string(query="{test}"), headers={"Accept": "text/html"}, - ) - assert response.status == 200 - assert pretty_response in response.text() - @pytest.mark.asyncio async def test_graphiql_html_is_not_accepted(client): @@ -97,7 +90,7 @@ async def test_graphiql_get_mutation(app, client): @pytest.mark.asyncio @pytest.mark.parametrize("app", [create_app(graphiql=True)]) -async def test_graphiql_get_subscriptions(client): +async def test_graphiql_get_subscriptions(app, client): response = await client.get( url_string( query="subscription TestSubscriptions { subscriptionsTest { test } }" diff --git a/tox.ini b/tox.ini index 2453c8b..35edfc5 100644 --- a/tox.ini +++ b/tox.ini @@ -14,7 +14,7 @@ whitelist_externals = python commands = pip install -U setuptools - pytest --cov-report=term-missing --cov=graphql_server tests {posargs} + pytest tests --cov-report=term-missing --cov=graphql_server {posargs} [testenv:black] basepython=python3.7