diff --git a/gql/client.py b/gql/client.py index dcfeb7af..8ca630f3 100644 --- a/gql/client.py +++ b/gql/client.py @@ -4,6 +4,8 @@ from graphql.validation import validate from .transport.local_schema import LocalSchemaTransport +from .type_adapter import TypeAdapter +from .exceptions import GQLServerError, GQLSyntaxError log = logging.getLogger(__name__) @@ -17,7 +19,10 @@ def __init__(self, retries_count, last_exception): class Client(object): def __init__(self, schema=None, introspection=None, type_def=None, transport=None, - fetch_schema_from_transport=False, retries=0): + fetch_schema_from_transport=False, retries=0, custom_types={}): + """custom_types should be of type Dict[str, Any] + where str is the name of the custom scalar type, and + Any is a class which has a `parse_value()` function""" assert not(type_def and introspection), 'Cant provide introspection type definition at the same time' if transport and fetch_schema_from_transport: assert not schema, 'Cant fetch the schema from transport if is already provided' @@ -36,10 +41,11 @@ def __init__(self, schema=None, introspection=None, type_def=None, transport=Non self.introspection = introspection self.transport = transport self.retries = retries + self.type_adapter = TypeAdapter(schema, custom_types) if custom_types else None def validate(self, document): if not self.schema: - raise Exception("Cannot validate locally the document, you need to pass a schema.") + raise GQLSyntaxError("Cannot validate locally the document, you need to pass a schema.") validation_errors = validate(self.schema, document) if validation_errors: raise validation_errors[0] @@ -50,7 +56,10 @@ def execute(self, document, *args, **kwargs): result = self._get_result(document, *args, **kwargs) if result.errors: - raise Exception(str(result.errors[0])) + raise GQLServerError(result.errors[0]) + + if self.type_adapter: + result.data = self.type_adapter.convert_scalars(result.data) return result.data diff --git a/gql/exceptions.py 
from graphql.type.definition import GraphQLObjectType, GraphQLScalarType


class GQLSyntaxError(Exception):
    """A problem with the GQL query or schema syntax."""


class GQLServerError(Exception):
    """Errors which should be explicitly handled by the calling code."""


class TypeAdapter(object):
    """Substitute custom scalars in a GQL response with their decoded counterparts.

    GQL custom scalar types are defined on the GQL schema and are used to
    represent fields which have special behaviour. To define a custom scalar
    type, you need the type name, and a class which has a class method called
    ``parse_value()`` -- this is the function which will be used to
    deserialize the custom scalar field.

    We first iterate over all the fields in the response (done inside
    ``convert_scalars()``). Each time we find a field which is a custom
    scalar (its type name appears as a key in ``self.custom_types``), we
    replace the value of that field with the decoded value.

    Public interface:
        convert_scalars(): pass in a GQL response to replace all instances
        of custom scalar strings with their deserialized representation.
    """

    def __init__(self, schema, custom_types=None):
        """
        Args:
            schema: a GraphQL schema in the GraphQLSchema format.
            custom_types: Dict[str, Any], where str is the name of the custom
                scalar type, and Any is a class which has a
                ``parse_value(str)`` class method.
        """
        self.schema = schema
        # Default to a fresh dict here instead of a mutable default argument,
        # which would be shared between every TypeAdapter instance.
        self.custom_types = custom_types if custom_types is not None else {}

    def _follow_type_chain(self, node):
        """Get the type of the schema node in question.

        In the GraphQL schema, GraphQLFields have a "type" property. However,
        that type often has an "of_type" property itself (wrapper types such
        as NonNull and List). To get to the actual type we follow the chain
        of "of_type" fields until we reach the innermost one.
        """
        if isinstance(node, GraphQLObjectType):
            return node

        field_type = node.type
        while hasattr(field_type, 'of_type'):
            field_type = field_type.of_type

        return field_type

    def _get_scalar_type_name(self, field):
        """Return the name of the field's type if it is a scalar type.

        Returns None otherwise.
        """
        node = self._follow_type_chain(field)
        if isinstance(node, GraphQLScalarType):
            return node.name
        return None

    def _lookup_scalar_type(self, keys):
        """Search through the GQL schema and return the type identified by 'keys'.

        If keys (e.g. ['film', 'release_date']) points to a scalar type, then
        this function returns the name of that type (e.g. 'DateTime').

        If it is not a scalar type (e.g. a GraphQLObject), this function
        returns None.

        `keys` is a breadcrumb trail telling us where to look in the GraphQL
        schema. By default the root level is `schema.query`; if that fails,
        we check `schema.mutation`.
        """

        def traverse_schema(node, lookup):
            if not lookup:
                return self._get_scalar_type_name(node)
            final_node = self._follow_type_chain(node)
            return traverse_schema(final_node.fields[lookup[0]], lookup[1:])

        if not keys:
            # An empty breadcrumb trail cannot identify a field.
            return None

        # get_query_type()/get_mutation_type() return None when the schema
        # does not define that operation type, so guard before dereferencing
        # `.fields` (the original code crashed on schemas with no mutation).
        query_type = self.schema.get_query_type()
        mutation_type = self.schema.get_mutation_type()
        if query_type is not None and keys[0] in query_type.fields:
            schema_root = query_type
        elif mutation_type is not None and keys[0] in mutation_type.fields:
            schema_root = mutation_type
        else:
            return None

        try:
            return traverse_schema(schema_root, keys)
        except (KeyError, AttributeError):
            return None

    def _get_decoded_scalar_type(self, keys, value):
        """Get the decoded value of the type identified by `keys`.

        If the type is not a custom scalar, return the original value.

        If it is a custom scalar, return the deserialized value, as output
        by ``parse_value()``.
        """
        scalar_type = self._lookup_scalar_type(keys)
        if scalar_type and scalar_type in self.custom_types:
            return self.custom_types[scalar_type].parse_value(value)
        return value

    def convert_scalars(self, response):
        """Recursively traverse the GQL response.

        Calls ``_get_decoded_scalar_type()`` for all leaf nodes, with:
            keys: List[str] -- a breadcrumb trail telling us where we are in
                the response, and therefore where to look in the GQL schema.
            value: Any -- the value at that node in the response.

        Builds a new tree with the substituted values, so the old `response`
        is not modified.
        """

        def iterate(node, keys):
            if isinstance(node, dict):
                return {key: iterate(value, keys + [key])
                        for key, value in node.items()}
            if isinstance(node, list):
                # List items share the same schema position as the list itself.
                return [iterate(item, keys) for item in node]
            return self._get_decoded_scalar_type(keys, node)

        return iterate(response, [])
+""" +import copy +from gql.type_adapter import TypeAdapter +import pytest +import requests +from gql import Client +from gql.transport.requests import RequestsHTTPTransport + +class Capitalize(): + @classmethod + def parse_value(self, value: str): + return value.upper(); + +@pytest.fixture(scope='session') +def schema(): + request = requests.get('http://swapi.graphene-python.org/graphql', + headers={ + 'Host': 'swapi.graphene-python.org', + 'Accept': 'text/html', + }) + request.raise_for_status() + csrf = request.cookies['csrftoken'] + + client = Client( + transport=RequestsHTTPTransport(url='http://swapi.graphene-python.org/graphql', + cookies={"csrftoken": csrf}, + headers={'x-csrftoken': csrf}), + fetch_schema_from_transport=True + ) + + return client.schema + +def test_scalar_type_name_for_scalar_field_returns_name(schema): + type_adapter = TypeAdapter(schema) + schema_obj = schema.get_query_type().fields['film'] + + assert type_adapter ._get_scalar_type_name(schema_obj.type.fields['releaseDate']) == 'DateTime' + + +def test_scalar_type_name_for_non_scalar_field_returns_none(schema): + type_adapter = TypeAdapter(schema) + schema_obj = schema.get_query_type().fields['film'] + + assert type_adapter._get_scalar_type_name(schema_obj.type.fields['species']) is None + +def test_lookup_scalar_type(schema): + type_adapter = TypeAdapter(schema) + + assert type_adapter._lookup_scalar_type(["film"]) is None + assert type_adapter._lookup_scalar_type(["film", "releaseDate"]) == 'DateTime' + assert type_adapter._lookup_scalar_type(["film", "species"]) is None + +def test_lookup_scalar_type_in_mutation(schema): + type_adapter = TypeAdapter(schema) + + assert type_adapter._lookup_scalar_type(["createHero"]) is None + assert type_adapter._lookup_scalar_type(["createHero", "hero"]) is None + assert type_adapter._lookup_scalar_type(["createHero", "ok"]) == 'Boolean' + +def test_parse_response(schema): + custom_types = { + 'DateTime': Capitalize + } + type_adapter = 
TypeAdapter(schema, custom_types) + + response = { + 'film': { + 'id': 'some_id', + 'releaseDate': 'some_datetime', + } + } + + expected = { + 'film': { + 'id': 'some_id', + 'releaseDate': 'SOME_DATETIME', + } + } + + assert type_adapter.convert_scalars(response) == expected + assert response['film']['releaseDate'] == 'some_datetime' # ensure original response is not changed + +def test_parse_response_containing_list(schema): + custom_types = { + 'DateTime': Capitalize + } + type_adapter = TypeAdapter(schema, custom_types) + + response = { + "allFilms": { + "edges": [{ + "node": { + 'id': 'some_id', + 'releaseDate': 'some_datetime', + } + },{ + "node": { + 'id': 'some_id', + 'releaseDate': 'some_other_datetime', + } + }] + } + } + + expected = copy.deepcopy(response) + expected['allFilms']['edges'][0]['node']['releaseDate'] = 'SOME_DATETIME' + expected['allFilms']['edges'][1]['node']['releaseDate'] = 'SOME_OTHER_DATETIME' + + result = type_adapter.convert_scalars(response) + assert result == expected + + assert response['allFilms']['edges'][0]['node']['releaseDate'] == 'some_datetime' # ensure original response is not changed + assert response['allFilms']['edges'][1]['node']['releaseDate'] == 'some_other_datetime' # ensure original response is not changed