Skip to content

refactor: graphiql template shared across servers #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions graphql_server/aiohttp/graphqlview.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
import copy
from collections.abc import MutableMapping
from functools import partial
from typing import List

from aiohttp import web
from graphql import GraphQLError
from graphql.type.schema import GraphQLSchema

from graphql_server import (
GraphQLParams,
HttpQueryError,
encode_execution_results,
format_error_default,
json_encode,
load_json_body,
run_http_query,
)

from .render_graphiql import render_graphiql
from graphql_server.render_graphiql import (
GraphiQLConfig,
GraphiQLData,
render_graphiql_async,
)


class GraphQLView:
@@ -26,12 +31,14 @@ class GraphQLView:
graphiql = False
graphiql_version = None
graphiql_template = None
graphiql_html_title = None
middleware = None
batch = False
jinja_env = None
max_age = 86400
enable_async = False
subscriptions = None
headers = None

accepted_methods = ["GET", "POST", "PUT", "DELETE"]

@@ -88,16 +95,6 @@ async def parse_body(self, request):

return {}

def render_graphiql(self, params, result):
return render_graphiql(
jinja_env=self.jinja_env,
params=params,
result=result,
graphiql_version=self.graphiql_version,
graphiql_template=self.graphiql_template,
subscriptions=self.subscriptions,
)

# TODO:
# use this method to replace flask and sanic
# checks as this is equivalent to `should_display_graphiql` and
@@ -135,6 +132,7 @@ async def __call__(self, request):
if request_method == "options":
return self.process_preflight(request)

all_params: List[GraphQLParams]
execution_results, all_params = run_http_query(
self.schema,
request_method,
@@ -162,7 +160,24 @@ async def __call__(self, request):
)

if is_graphiql:
return await self.render_graphiql(params=all_params[0], result=result)
graphiql_data = GraphiQLData(
result=result,
query=getattr(all_params[0], "query"),
variables=getattr(all_params[0], "variables"),
operation_name=getattr(all_params[0], "operation_name"),
subscription_url=self.subscriptions,
headers=self.headers,
)
graphiql_config = GraphiQLConfig(
graphiql_version=self.graphiql_version,
graphiql_template=self.graphiql_template,
graphiql_html_title=self.graphiql_html_title,
jinja_env=self.jinja_env,
)
source = await render_graphiql_async(
data=graphiql_data, config=graphiql_config
)
return web.Response(text=source, content_type="text/html")

return web.Response(
text=result, status=status_code, content_type="application/json",
208 changes: 0 additions & 208 deletions graphql_server/aiohttp/render_graphiql.py

This file was deleted.

44 changes: 30 additions & 14 deletions graphql_server/flask/graphqlview.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
from functools import partial
from typing import List

from flask import Response, request
from flask import Response, render_template_string, request
from flask.views import View
from graphql.error import GraphQLError
from graphql.type.schema import GraphQLSchema

from graphql_server import (
GraphQLParams,
HttpQueryError,
encode_execution_results,
format_error_default,
json_encode,
load_json_body,
run_http_query,
)

from .render_graphiql import render_graphiql
from graphql_server.render_graphiql import (
GraphiQLConfig,
GraphiQLData,
render_graphiql_sync,
)


class GraphQLView(View):
@@ -27,6 +32,8 @@ class GraphQLView(View):
graphiql_html_title = None
middleware = None
batch = False
subscriptions = None
headers = None

methods = ["GET", "POST", "PUT", "DELETE"]

@@ -50,15 +57,6 @@ def get_context_value(self):
def get_middleware(self):
return self.middleware

def render_graphiql(self, params, result):
return render_graphiql(
params=params,
result=result,
graphiql_version=self.graphiql_version,
graphiql_template=self.graphiql_template,
graphiql_html_title=self.graphiql_html_title,
)

format_error = staticmethod(format_error_default)
encode = staticmethod(json_encode)

@@ -72,6 +70,7 @@ def dispatch_request(self):

pretty = self.pretty or show_graphiql or request.args.get("pretty")

all_params: List[GraphQLParams]
execution_results, all_params = run_http_query(
self.schema,
request_method,
@@ -88,11 +87,28 @@ def dispatch_request(self):
execution_results,
is_batch=isinstance(data, list),
format_error=self.format_error,
encode=partial(self.encode, pretty=pretty),
encode=partial(self.encode, pretty=pretty), # noqa
)

if show_graphiql:
return self.render_graphiql(params=all_params[0], result=result)
graphiql_data = GraphiQLData(
result=result,
query=getattr(all_params[0], "query"),
variables=getattr(all_params[0], "variables"),
operation_name=getattr(all_params[0], "operation_name"),
subscription_url=self.subscriptions,
headers=self.headers,
)
graphiql_config = GraphiQLConfig(
graphiql_version=self.graphiql_version,
graphiql_template=self.graphiql_template,
graphiql_html_title=self.graphiql_html_title,
jinja_env=None,
)
source = render_graphiql_sync(
data=graphiql_data, config=graphiql_config
)
return render_template_string(source)

return Response(result, status=status_code, content_type="application/json")

Loading