Skip to content

Added batch support #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 14, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ from flask_graphql import GraphQLView

app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=True))

# Optional, for adding batch query support (used in Apollo-Client)
app.add_url_rule('/graphql/batch', view_func=GraphQLView.as_view('graphql', schema=schema, batch=True))
```

This will add `/graphql` and `/graphiql` endpoints to your app.
Expand Down
77 changes: 51 additions & 26 deletions flask_graphql/graphqlview.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class GraphQLView(View):
graphiql_version = None
graphiql_template = None
middleware = None
batch = False

methods = ['GET', 'POST', 'PUT', 'DELETE']

Expand All @@ -41,6 +42,7 @@ def __init__(self, **kwargs):
if hasattr(self, key):
setattr(self, key, value)

assert not all((self.graphiql, self.batch)), 'Use either graphiql or batch processing'
assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'

# noinspection PyUnusedLocal
Expand All @@ -66,33 +68,15 @@ def dispatch_request(self):
data = self.parse_body(request)
show_graphiql = self.graphiql and self.can_display_graphiql(data)

query, variables, operation_name = self.get_graphql_params(request, data)

execution_result = self.execute_graphql_request(
data,
query,
variables,
operation_name,
show_graphiql
)

if execution_result:
response = {}

if execution_result.errors:
response['errors'] = [self.format_error(e) for e in execution_result.errors]

if execution_result.invalid:
status_code = 400
else:
status_code = 200
response['data'] = execution_result.data

result = self.json_encode(request, response)
if self.batch:
responses = [self.get_response(request, entry) for entry in data]
result = '[{}]'.format(','.join([response[0] for response in responses]))
status_code = max(responses, key=lambda response: response[1])[1]
else:
result = None
result, status_code = self.get_response(request, data, show_graphiql)

if show_graphiql:
query, variables, operation_name, id = self.get_graphql_params(request, data)
return render_graphiql(
graphiql_version=self.graphiql_version,
graphiql_template=self.graphiql_template,
Expand All @@ -118,6 +102,43 @@ def dispatch_request(self):
content_type='application/json'
)

def get_response(self, request, data, show_graphiql=False):
query, variables, operation_name, id = self.get_graphql_params(request, data)

execution_result = self.execute_graphql_request(
data,
query,
variables,
operation_name,
show_graphiql
)

status_code = 200
if execution_result:
response = {}

if execution_result.errors:
response['errors'] = [self.format_error(e) for e in execution_result.errors]

if execution_result.invalid:
status_code = 400
else:
status_code = 200
response['data'] = execution_result.data

if self.batch:
response = {
'id': id,
'payload': response,
'status': status_code,
}

result = self.json_encode(request, response)
else:
result = None

return result, status_code

def json_encode(self, request, d):
if not self.pretty and not request.args.get('pretty'):
return json.dumps(d, separators=(',', ':'))
Expand All @@ -134,7 +155,10 @@ def parse_body(self, request):
elif content_type == 'application/json':
try:
request_json = json.loads(request.data.decode('utf8'))
assert isinstance(request_json, dict)
if self.batch:
assert isinstance(request_json, list)
else:
assert isinstance(request_json, dict)
return request_json
except:
raise HttpError(BadRequest('POST body sent invalid JSON.'))
Expand Down Expand Up @@ -207,6 +231,7 @@ def request_wants_html(cls, request):
def get_graphql_params(request, data):
query = request.args.get('query') or data.get('query')
variables = request.args.get('variables') or data.get('variables')
id = request.args.get('id') or data.get('id')

if variables and isinstance(variables, six.text_type):
try:
Expand All @@ -216,7 +241,7 @@ def get_graphql_params(request, data):

operation_name = request.args.get('operationName') or data.get('operationName')

return query, variables, operation_name
return query, variables, operation_name, id

@staticmethod
def format_error(error):
Expand Down
4 changes: 2 additions & 2 deletions tests/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from .schema import Schema


def create_app(**kwargs):
def create_app(path='/graphql', **kwargs):
app = Flask(__name__)
app.debug = True
app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=Schema, **kwargs))
app.add_url_rule(path, view_func=GraphQLView.as_view('graphql', schema=Schema, **kwargs))
return app


Expand Down
70 changes: 70 additions & 0 deletions tests/test_graphqlview.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def response_json(response):


j = lambda **kwargs: json.dumps(kwargs)
jl = lambda **kwargs: json.dumps([kwargs])


def test_allows_get_with_query_param(client):
response = client.get(url_string(query='{test}'))
Expand Down Expand Up @@ -453,3 +455,71 @@ def test_post_multipart_data(client):

assert response.status_code == 200
assert response_json(response) == {'data': {u'writeTest': {u'test': u'Hello World'}}}


@pytest.mark.parametrize('app', [create_app(batch=True)])
def test_batch_allows_post_with_json_encoding(client):
response = client.post(
url_string(),
data=jl(id=1, query='{test}'),
content_type='application/json'
)

assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'payload': { 'data': {'test': "Hello World"} },
'status': 200,
}]


@pytest.mark.parametrize('app', [create_app(batch=True)])
def test_batch_supports_post_json_query_with_json_variables(client):
response = client.post(
url_string(),
data=jl(
id=1,
query='query helloWho($who: String){ test(who: $who) }',
variables={'who': "Dolly"}
),
content_type='application/json'
)

assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'payload': { 'data': {'test': "Hello Dolly"} },
'status': 200,
}]


@pytest.mark.parametrize('app', [create_app(batch=True)])
def test_batch_allows_post_with_operation_name(client):
response = client.post(
url_string(),
data=jl(
id=1,
query='''
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
operationName='helloWorld'
),
content_type='application/json'
)

assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'payload': {
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
}
},
'status': 200,
}]