Skip to content

Add support for custom scalar types #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from graphql.validation import validate

from .transport.local_schema import LocalSchemaTransport
from .type_adapter import TypeAdapter
from .exceptions import GQLServerError, GQLSyntaxError

log = logging.getLogger(__name__)

Expand All @@ -17,7 +19,10 @@ def __init__(self, retries_count, last_exception):

class Client(object):
def __init__(self, schema=None, introspection=None, type_def=None, transport=None,
fetch_schema_from_transport=False, retries=0):
fetch_schema_from_transport=False, retries=0, custom_types={}):
"""custom_types should be of type Dict[str, Any]
where str is the name of the custom scalar type, and
Any is a class which has a `parse_value()` function"""
assert not(type_def and introspection), 'Cant provide introspection type definition at the same time'
if transport and fetch_schema_from_transport:
assert not schema, 'Cant fetch the schema from transport if is already provided'
Expand All @@ -36,10 +41,11 @@ def __init__(self, schema=None, introspection=None, type_def=None, transport=Non
self.introspection = introspection
self.transport = transport
self.retries = retries
self.type_adapter = TypeAdapter(schema, custom_types) if custom_types else None

def validate(self, document):
if not self.schema:
raise Exception("Cannot validate locally the document, you need to pass a schema.")
raise GQLSyntaxError("Cannot validate locally the document, you need to pass a schema.")
validation_errors = validate(self.schema, document)
if validation_errors:
raise validation_errors[0]
Expand All @@ -50,7 +56,10 @@ def execute(self, document, *args, **kwargs):

result = self._get_result(document, *args, **kwargs)
if result.errors:
raise Exception(str(result.errors[0]))
raise GQLServerError(result.errors[0])

if self.type_adapter:
result.data = self.type_adapter.convert_scalars(result.data)

return result.data

Expand Down
5 changes: 5 additions & 0 deletions gql/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class GQLSyntaxError(Exception):
"""A problem with the GQL query or schema syntax"""

class GQLServerError(Exception):
"""Errors which should be explicitly handled by the calling code"""
117 changes: 117 additions & 0 deletions gql/type_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from graphql.type.definition import GraphQLObjectType, GraphQLScalarType


class TypeAdapter(object):
"""Substitute custom scalars in a GQL response with their decoded counterparts.

GQL custom scalar types are defined on the GQL schema and are used to represent
fields which have special behaviour. To define custom scalar type, you need
the type name, and a class which has a class method called `parse_value()` -
this is the function which will be used to deserialize the custom scalar field.

We first need iterate over all the fields in the response (which is done in
the `_traverse()` function).

Each time we find a field which is a custom scalar (it's type name appears
as a key in self.custom_types), we replace the value of that field with the
decoded value. All of this logic happens in `_substitute()`.

Public Interface:
apply(): pass in a GQL response to replace all instances of custom
scalar strings with their deserialized representation."""

def __init__(self, schema, custom_types = {}):
""" schema: a graphQL schema in the GraphQLSchema format
custom_types: a Dict[str, Any],
where str is the name of the custom scalar type, and
Any is a class which has a `parse_value(str)` function"""
self.schema = schema
self.custom_types = custom_types

def _follow_type_chain(self, node):
""" Get the type of the schema node in question.

In the GraphQL schema, GraphQLFields have a "type" property. However, often
that dict has an "of_type" property itself. In order to get to the actual
type, we need to indefinitely follow the chain of "of_type" fields to get
to the last one, which is the one we care about."""
if isinstance(node, GraphQLObjectType):
return node

field_type = node.type
while hasattr(field_type, 'of_type'):
field_type = field_type.of_type

return field_type

def _get_scalar_type_name(self, field):
"""Returns the name of the type if the type is a scalar type.
Returns None otherwise"""
node = self._follow_type_chain(field)
if isinstance(node, GraphQLScalarType):
return node.name
return None

def _lookup_scalar_type(self, keys):
"""Search through the GQL schema and return the type identified by 'keys'.

If keys (e.g. ['film', 'release_date']) points to a scalar type, then
this function returns the name of that type. (e.g. 'DateTime')

If it is not a scalar type (e..g a GraphQLObject), then this
function returns None.

`keys` is a breadcrumb trail telling us where to look in the GraphQL schema.
By default the root level is `schema.query`, if that fails, then we check
`schema.mutation`."""

def traverse_schema(node, lookup):
if not lookup:
return self._get_scalar_type_name(node)

final_node = self._follow_type_chain(node)
return traverse_schema(final_node.fields[lookup[0]], lookup[1:])

if keys[0] in self.schema.get_query_type().fields:
schema_root = self.schema.get_query_type()
elif keys[0] in self.schema.get_mutation_type().fields:
schema_root = self.schema.get_mutation_type()
else:
return None

try:
return traverse_schema(schema_root, keys)
except (KeyError, AttributeError):
return None

def _get_decoded_scalar_type(self, keys, value):
"""Get the decoded value of the type identified by `keys`.

If the type is not a custom scalar, then return the original value.

If it is a custom scalar, return the deserialized value, as
output by `<CustomScalarType>.parse_value()`"""
scalar_type = self._lookup_scalar_type(keys)
if scalar_type and scalar_type in self.custom_types:
return self.custom_types[scalar_type].parse_value(value)
return value

def convert_scalars(self, response):
"""Recursively traverse the GQL response

Recursively traverses the GQL response and calls _get_decoded_scalar_type()
for all leaf nodes. The function is called with 2 arguments:
keys: List[str] is a breadcrumb trail telling us where we are in the
response, and therefore, where to look in the GQL Schema.
value: Any is the value at that node in the response

Builds a new tree with the substituted values so old `response` is not
modified."""
def iterate(node, keys = []):
if isinstance(node, dict):
return {_key: iterate(value, keys + [_key]) for _key, value in node.items()}
elif isinstance(node, list):
return [(iterate(item, keys)) for item in node]
else:
return self._get_decoded_scalar_type(keys, node)
return iterate(response)
118 changes: 118 additions & 0 deletions tests/test_type_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Tests for the GraphQL Response Parser.

At the moment we use the Star Wars schema which is fetched each time from the
server endpoint. In future it would be better to store this schema in a file
locally.
"""
import copy
from gql.type_adapter import TypeAdapter
import pytest
import requests
from gql import Client
from gql.transport.requests import RequestsHTTPTransport

class Capitalize():
@classmethod
def parse_value(self, value: str):
return value.upper();

@pytest.fixture(scope='session')
def schema():
request = requests.get('http://swapi.graphene-python.org/graphql',
headers={
'Host': 'swapi.graphene-python.org',
'Accept': 'text/html',
})
request.raise_for_status()
csrf = request.cookies['csrftoken']

client = Client(
transport=RequestsHTTPTransport(url='http://swapi.graphene-python.org/graphql',
cookies={"csrftoken": csrf},
headers={'x-csrftoken': csrf}),
fetch_schema_from_transport=True
)

return client.schema

def test_scalar_type_name_for_scalar_field_returns_name(schema):
type_adapter = TypeAdapter(schema)
schema_obj = schema.get_query_type().fields['film']

assert type_adapter ._get_scalar_type_name(schema_obj.type.fields['releaseDate']) == 'DateTime'


def test_scalar_type_name_for_non_scalar_field_returns_none(schema):
type_adapter = TypeAdapter(schema)
schema_obj = schema.get_query_type().fields['film']

assert type_adapter._get_scalar_type_name(schema_obj.type.fields['species']) is None

def test_lookup_scalar_type(schema):
type_adapter = TypeAdapter(schema)

assert type_adapter._lookup_scalar_type(["film"]) is None
assert type_adapter._lookup_scalar_type(["film", "releaseDate"]) == 'DateTime'
assert type_adapter._lookup_scalar_type(["film", "species"]) is None

def test_lookup_scalar_type_in_mutation(schema):
type_adapter = TypeAdapter(schema)

assert type_adapter._lookup_scalar_type(["createHero"]) is None
assert type_adapter._lookup_scalar_type(["createHero", "hero"]) is None
assert type_adapter._lookup_scalar_type(["createHero", "ok"]) == 'Boolean'

def test_parse_response(schema):
custom_types = {
'DateTime': Capitalize
}
type_adapter = TypeAdapter(schema, custom_types)

response = {
'film': {
'id': 'some_id',
'releaseDate': 'some_datetime',
}
}

expected = {
'film': {
'id': 'some_id',
'releaseDate': 'SOME_DATETIME',
}
}

assert type_adapter.convert_scalars(response) == expected
assert response['film']['releaseDate'] == 'some_datetime' # ensure original response is not changed

def test_parse_response_containing_list(schema):
custom_types = {
'DateTime': Capitalize
}
type_adapter = TypeAdapter(schema, custom_types)

response = {
"allFilms": {
"edges": [{
"node": {
'id': 'some_id',
'releaseDate': 'some_datetime',
}
},{
"node": {
'id': 'some_id',
'releaseDate': 'some_other_datetime',
}
}]
}
}

expected = copy.deepcopy(response)
expected['allFilms']['edges'][0]['node']['releaseDate'] = 'SOME_DATETIME'
expected['allFilms']['edges'][1]['node']['releaseDate'] = 'SOME_OTHER_DATETIME'

result = type_adapter.convert_scalars(response)
assert result == expected

assert response['allFilms']['edges'][0]['node']['releaseDate'] == 'some_datetime' # ensure original response is not changed
assert response['allFilms']['edges'][1]['node']['releaseDate'] == 'some_other_datetime' # ensure original response is not changed