From 97cd612bdd89c45e400ffaa1afc1697c344e653e Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 12 Nov 2021 15:16:03 +0100 Subject: [PATCH 01/21] Implementation of serialization of variable values --- gql/client.py | 199 ++++++- gql/variable_values.py | 117 ++++ tests/custom_scalars/__init__.py | 0 .../test_custom_scalar_datetime.py | 174 ++++++ .../test_custom_scalar_money.py | 527 ++++++++++++++++++ 5 files changed, 1005 insertions(+), 12 deletions(-) create mode 100644 gql/variable_values.py create mode 100644 tests/custom_scalars/__init__.py create mode 100644 tests/custom_scalars/test_custom_scalar_datetime.py create mode 100644 tests/custom_scalars/test_custom_scalar_money.py diff --git a/gql/client.py b/gql/client.py index 6017ab69..368193cc 100644 --- a/gql/client.py +++ b/gql/client.py @@ -17,6 +17,7 @@ from .transport.exceptions import TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport +from .variable_values import serialize_variable_values class Client: @@ -289,18 +290,79 @@ def __init__(self, client: Client): """:param client: the :class:`client ` used""" self.client = client - def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: + def _execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> ExecutionResult: + """Execute the provided document AST synchronously using + the sync transport, returning an ExecutionResult object. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport execute method.""" # Validate document if self.client.schema: self.client.validate(document) - return self.transport.execute(document, *args, **kwargs) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + + return self.transport.execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) - def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> Dict: + """Execute the provided document AST synchronously using + the sync transport. + + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport execute method.""" # Validate and execute on the transport - result = self._execute(document, *args, **kwargs) + result = self._execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, + ) # Raise an error if an error is returned in the ExecutionResult object if result.errors: @@ -341,17 +403,52 @@ def __init__(self, client: Client): self.client = client async def _subscribe( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> AsyncGenerator[ExecutionResult, None]: + """Coroutine to subscribe asynchronously to the provided document AST + asynchronously using the async transport, + returning an async generator producing ExecutionResult objects. + + * Validate the query with the schema if provided. + * Serialize the variable_values if requested. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport subscribe method.""" # Validate document if self.client.schema: self.client.validate(document) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + # Subscribe to the transport inner_generator: AsyncGenerator[ ExecutionResult, None - ] = self.transport.subscribe(document, *args, **kwargs) + ] = self.transport.subscribe( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) # Keep a reference to the inner generator to allow the user to call aclose() # before a break if python version is too old (pypy3 py 3.6.1) @@ -364,15 +461,35 @@ async def _subscribe( await inner_generator.aclose() async def subscribe( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> AsyncGenerator[Dict, None]: """Coroutine to subscribe asynchronously to the provided document AST asynchronously using the async transport. + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + The extra arguments are passed to the transport subscribe method.""" inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( - document, *args, **kwargs + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, ) try: @@ -391,27 +508,85 @@ async def subscribe( await inner_generator.aclose() async def _execute( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> ExecutionResult: + """Coroutine to execute the provided document AST asynchronously using + the async transport, returning an ExecutionResult object. + + * Validate the query with the schema if provided. + * Serialize the variable_values if requested. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport execute method.""" # Validate document if self.client.schema: self.client.validate(document) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + # Execute the query with the transport with a timeout return await asyncio.wait_for( - self.transport.execute(document, *args, **kwargs), + self.transport.execute( + document, + variable_values=variable_values, + operation_name=operation_name, + *args, + **kwargs, + ), self.client.execute_timeout, ) - async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + async def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> Dict: """Coroutine to execute the provided document AST asynchronously using the async transport. + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + The extra arguments are passed to the transport execute method.""" # Validate and execute on the transport - result = await self._execute(document, *args, **kwargs) + result = await self._execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, + ) # Raise an error if an error is returned in the ExecutionResult object if result.errors: diff --git a/gql/variable_values.py b/gql/variable_values.py new file mode 100644 index 00000000..7db7091a --- /dev/null +++ b/gql/variable_values.py @@ -0,0 +1,117 @@ +from typing import Any, Dict, Optional + +from graphql import ( + DocumentNode, + GraphQLEnumType, + GraphQLError, + GraphQLInputObjectType, + GraphQLList, + GraphQLNonNull, + GraphQLScalarType, + GraphQLSchema, + GraphQLType, + GraphQLWrappingType, + OperationDefinitionNode, + type_from_ast, +) +from graphql.pyutils import inspect + + +def get_document_operation( + document: DocumentNode, operation_name: Optional[str] = None +) -> OperationDefinitionNode: + """Returns the operation which should be executed in the document. + + Raises a GraphQLError if a single operation cannot be retrieved. + """ + + operation: Optional[OperationDefinitionNode] = None + + for definition in document.definitions: + if isinstance(definition, OperationDefinitionNode): + if operation_name is None: + if operation: + raise GraphQLError( + "Must provide operation name" + " if query contains multiple operations." + ) + operation = definition + elif definition.name and definition.name.value == operation_name: + operation = definition + + if not operation: + if operation_name is not None: + raise GraphQLError(f"Unknown operation named '{operation_name}'.") + + # The following line should never happen normally as the document is + # already verified before calling this function. + raise GraphQLError("Must provide an operation.") # pragma: no cover + + return operation + + +def serialize_value(type_: GraphQLType, value: Any) -> Any: + """Given a GraphQL type and a Python value, return the serialized value. + + Can be used to serialize Enums and/or Custom Scalars in variable values. + """ + + if value is None: + if isinstance(type_, GraphQLNonNull): + # raise GraphQLError(f"Type {type_.of_type.name} Cannot be None.") + raise GraphQLError(f"Type {inspect(type_)} Cannot be None.") + else: + return None + + if isinstance(type_, GraphQLWrappingType): + inner_type = type_.of_type + + if isinstance(type_, GraphQLNonNull): + return serialize_value(inner_type, value) + + elif isinstance(type_, GraphQLList): + return [serialize_value(inner_type, v) for v in value] + + elif isinstance(type_, (GraphQLScalarType, GraphQLEnumType)): + return type_.serialize(value) + + elif isinstance(type_, GraphQLInputObjectType): + return { + field_name: serialize_value(field.type, value[field_name]) + for field_name, field in type_.fields.items() + } + + raise GraphQLError(f"Impossible to serialize value with type: {inspect(type_)}.") + + +def serialize_variable_values( + schema: GraphQLSchema, + document: DocumentNode, + variable_values: Dict[str, Any], + operation_name: Optional[str] = None, +) -> Dict[str, Any]: + """Given a GraphQL document and a schema, serialize the Dictionary of + variable values. + + Useful to serialize Enums and/or Custom Scalars in variable values + """ + + parsed_variable_values: Dict[str, Any] = {} + + # Find the operation in the document + operation = get_document_operation(document, operation_name=operation_name) + + # Serialize every variable value defined for the operation + for var_def_node in operation.variable_definitions: + var_name = var_def_node.variable.name.value + var_type = type_from_ast(schema, var_def_node.type) + + if var_name in variable_values: + + assert var_type is not None + + var_value = variable_values[var_name] + + parsed_variable_values[var_name] = serialize_value(var_type, var_value) + + return parsed_variable_values diff --git a/tests/custom_scalars/__init__.py b/tests/custom_scalars/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_custom_scalar_datetime.py new file mode 100644 index 00000000..22b3987a --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_datetime.py @@ -0,0 +1,174 @@ +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +import pytest +from graphql.error import GraphQLError +from graphql.language import ValueNode +from graphql.pyutils import inspect +from graphql.type import ( + GraphQLArgument, + GraphQLField, + GraphQLInputField, + GraphQLInputObjectType, + GraphQLInt, + GraphQLList, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +def serialize_datetime(value: Any) -> str: + if not isinstance(value, datetime): + raise GraphQLError("Cannot serialize datetime value: " + inspect(value)) + return value.isoformat() + + +def parse_datetime_value(value: Any) -> datetime: + + if isinstance(value, str): + try: + # Note: a more solid custom scalar should use dateutil.parser.isoparse + # Not using it here in the test to avoid adding another dependency + return datetime.fromisoformat(value) + except Exception: + raise GraphQLError("Cannot parse datetime value : " + inspect(value)) + + else: + raise GraphQLError("Cannot parseee datetime value: " + inspect(value)) + + +def parse_datetime_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> datetime: + ast_value = value_from_ast_untyped(value_node, variables) + if not isinstance(ast_value, str): + raise GraphQLError("Cannot parse literal datetime value: " + inspect(ast_value)) + + return parse_datetime_value(ast_value) + + +DatetimeScalar = GraphQLScalarType( + name="Datetime", + serialize=serialize_datetime, + parse_value=parse_datetime_value, + parse_literal=parse_datetime_literal, +) + + +def resolve_shift_days(root, _info, time, days): + return time + timedelta(days=days) + + +def resolve_latest(root, _info, times): + return max(times) + + +def resolve_seconds(root, _info, interval): + print(f"interval={interval!r}") + return (interval["end"] - interval["start"]).total_seconds() + + +IntervalInputType = GraphQLInputObjectType( + "IntervalInput", + fields={ + "start": GraphQLInputField(DatetimeScalar), + "end": GraphQLInputField(DatetimeScalar), + }, +) + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "shiftDays": GraphQLField( + DatetimeScalar, + args={ + "time": GraphQLArgument(DatetimeScalar), + "days": GraphQLArgument(GraphQLInt), + }, + resolve=resolve_shift_days, + ), + "latest": GraphQLField( + DatetimeScalar, + args={"times": GraphQLArgument(GraphQLList(DatetimeScalar))}, + resolve=resolve_latest, + ), + "seconds": GraphQLField( + GraphQLInt, + args={"interval": GraphQLArgument(IntervalInputType)}, + resolve=resolve_seconds, + ), + }, +) + +schema = GraphQLSchema(query=queryType) + + +def test_shift_days(): + + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": now, + } + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + +def test_latest(): + + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") + + query = gql("query latest($times: [Datetime!]!) {latest(times: $times)}") + + variable_values = { + "times": [now, in_five_days], + } + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["latest"] == in_five_days.isoformat() + + +def test_seconds(): + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") + + query = gql( + "query seconds($interval: IntervalInput) {seconds(interval: $interval)}" + ) + + variable_values = {"interval": {"start": now, "end": in_five_days}} + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["seconds"] == 432000 diff --git a/tests/custom_scalars/test_custom_scalar_money.py b/tests/custom_scalars/test_custom_scalar_money.py new file mode 100644 index 00000000..6560a3eb --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_money.py @@ -0,0 +1,527 @@ +import asyncio +from typing import Any, Dict, NamedTuple, Optional + +import pytest +from graphql import graphql_sync +from graphql.error import GraphQLError +from graphql.language import ValueNode +from graphql.pyutils import inspect, is_finite +from graphql.type import ( + GraphQLArgument, + GraphQLField, + GraphQLFloat, + GraphQLInt, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql +from gql.variable_values import serialize_value + +from ..conftest import MS + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +class Money(NamedTuple): + amount: float + currency: str + + +def serialize_money(output_value: Any) -> Dict[str, Any]: + if not isinstance(output_value, Money): + raise GraphQLError("Cannot serialize money value: " + inspect(output_value)) + return output_value._asdict() + + +def parse_money_value(input_value: Any) -> Money: + """Using Money custom scalar from graphql-core tests except here the + input value is supposed to be a dict instead of a Money object.""" + + """ + if isinstance(input_value, Money): + return input_value + """ + + if isinstance(input_value, dict): + amount = input_value.get("amount", None) + currency = input_value.get("currency", None) + + if not is_finite(amount) or not isinstance(currency, str): + raise GraphQLError("Cannot parse money value dict: " + inspect(input_value)) + + return Money(float(amount), currency) + else: + raise GraphQLError("Cannot parse money value: " + inspect(input_value)) + + +def parse_money_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> Money: + money = value_from_ast_untyped(value_node, variables) + if variables is not None and ( + # variables are not set when checked with ValuesIOfCorrectTypeRule + not money + or not is_finite(money.get("amount")) + or not isinstance(money.get("currency"), str) + ): + raise GraphQLError("Cannot parse literal money value: " + inspect(money)) + return Money(**money) + + +MoneyScalar = GraphQLScalarType( + name="Money", + serialize=serialize_money, + parse_value=parse_money_value, + parse_literal=parse_money_literal, +) + + +def resolve_balance(root, _info): + return root + + +def resolve_to_euros(_root, _info, money): + amount = money.amount + currency = money.currency + if not amount or currency == "EUR": + return amount + if currency == "DM": + return amount * 0.5 + raise ValueError("Cannot convert to euros: " + inspect(money)) + + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "balance": GraphQLField(MoneyScalar, resolve=resolve_balance), + "toEuros": GraphQLField( + GraphQLFloat, + args={"money": GraphQLArgument(MoneyScalar)}, + resolve=resolve_to_euros, + ), + }, +) + + +def resolve_spent_money(spent_money, _info, **kwargs): + return spent_money + + +async def subscribe_spend_all(_root, _info, money): + while money.amount > 0: + money = Money(money.amount - 1, money.currency) + yield money + await asyncio.sleep(1 * MS) + + +subscriptionType = GraphQLObjectType( + "Subscription", + fields=lambda: { + "spend": GraphQLField( + MoneyScalar, + args={"money": GraphQLArgument(MoneyScalar)}, + subscribe=subscribe_spend_all, + resolve=resolve_spent_money, + ) + }, +) + +root_value = Money(42, "DM") + +schema = GraphQLSchema(query=queryType, subscription=subscriptionType,) + + +def test_custom_scalar_in_output(): + + client = Client(schema=schema) + + query = gql("{balance}") + + result = client.execute(query, root_value=root_value) + + print(result) + + assert result["balance"] == serialize_money(root_value) + + +def test_custom_scalar_in_input_query(): + + client = Client(schema=schema) + + query = gql('{toEuros(money: {amount: 10, currency: "DM"})}') + + result = client.execute(query, root_value=root_value) + + assert result["toEuros"] == 5 + + query = gql('{toEuros(money: {amount: 10, currency: "EUR"})}') + + result = client.execute(query, root_value=root_value) + + assert result["toEuros"] == 10 + + +def test_custom_scalar_in_input_variable_values(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = {"amount": 10, "currency": "DM"} + + variable_values = {"money": money_value} + + result = client.execute( + query, variable_values=variable_values, root_value=root_value + ) + + assert result["toEuros"] == 5 + + +def test_custom_scalar_in_input_variable_values_serialized(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ) + + assert result["toEuros"] == 5 + + +def test_custom_scalar_in_input_variable_values_serialized_with_operation_name(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + operation_name="myquery", + ) + + assert result["toEuros"] == 5 + + +def test_serialize_variable_values_exception_multiple_ops_without_operation_name(): + + client = Client(schema=schema) + + query = gql( + """ + query myconversion($money: Money) { + toEuros(money: $money) + } + + query mybalance { + balance + }""" + ) + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + with pytest.raises(GraphQLError) as exc_info: + client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ) + + exception = exc_info.value + + assert ( + str(exception) + == "Must provide operation name if query contains multiple operations." + ) + + +def test_serialize_variable_values_exception_operation_name_not_found(): + + client = Client(schema=schema) + + query = gql( + """ + query myconversion($money: Money) { + toEuros(money: $money) + } +""" + ) + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + with pytest.raises(GraphQLError) as exc_info: + client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + operation_name="invalid_operation_name", + ) + + exception = exc_info.value + + assert str(exception) == "Unknown operation named 'invalid_operation_name'." + + +def test_custom_scalar_subscribe_in_input_variable_values_serialized(): + + client = Client(schema=schema) + + query = gql("subscription spendAll($money: Money) {spend(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + expected_result = {"spend": {"amount": 10, "currency": "DM"}} + + for result in client.subscribe( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ): + print(f"result = {result!r}") + expected_result["spend"]["amount"] = expected_result["spend"]["amount"] - 1 + assert expected_result == result + + +async def make_money_backend(aiohttp_server): + from aiohttp import web + + async def handler(request): + data = await request.json() + source = data["query"] + + print(f"data keys = {data.keys()}") + try: + variables = data["variables"] + print(f"variables = {variables!r}") + except KeyError: + variables = None + + result = graphql_sync( + schema, source, variable_values=variables, root_value=root_value + ) + + print(f"backend result = {result!r}") + + return web.json_response( + { + "data": result.data, + "errors": [str(e) for e in result.errors] if result.errors else None, + } + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + return server + + +async def make_money_transport(aiohttp_server): + from gql.transport.aiohttp import AIOHTTPTransport + + server = await make_money_backend(aiohttp_server) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + return transport + + +async def make_sync_money_transport(aiohttp_server): + from gql.transport.requests import RequestsHTTPTransport + + server = await make_money_backend(aiohttp_server) + + url = server.make_url("/") + + transport = RequestsHTTPTransport(url=url, timeout=10) + + return (server, transport) + + +@pytest.mark.asyncio +async def test_custom_scalar_in_output_with_transport(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("{balance}") + + result = await session.execute(query) + + print(result) + + assert result["balance"] == serialize_money(root_value) + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_query_with_transport(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql('{toEuros(money: {amount: 10, currency: "DM"})}') + + result = await session.execute(query) + + assert result["toEuros"] == 5 + + query = gql('{toEuros(money: {amount: 10, currency: "EUR"})}') + + result = await session.execute(query) + + assert result["toEuros"] == 10 + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_variable_values_with_transport( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = {"amount": 10, "currency": "DM"} + # money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = await session.execute(query, variable_values=variable_values) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_variable_values_split_with_transport( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql( + """ +query myquery($amount: Float, $currency: String) { + toEuros(money: {amount: $amount, currency: $currency}) +}""" + ) + + variable_values = {"amount": 10, "currency": "DM"} + + result = await session.execute(query, variable_values=variable_values) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(schema=schema, transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +@pytest.mark.requests +async def test_custom_scalar_serialize_variables_sync_transport( + event_loop, aiohttp_server, run_sync_test +): + + server, transport = await make_sync_money_transport(aiohttp_server) + + def test_code(): + with Client(schema=schema, transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + await run_sync_test(event_loop, server, test_code) + + +def test_serialize_value_with_invalid_type(): + + with pytest.raises(GraphQLError) as exc_info: + serialize_value("Not a valid type", 50) + + exception = exc_info.value + + assert ( + str(exception) == "Impossible to serialize value with type: 'Not a valid type'." + ) + + +def test_serialize_value_with_non_null_type_null(): + + non_null_int = GraphQLNonNull(GraphQLInt) + + with pytest.raises(GraphQLError) as exc_info: + serialize_value(non_null_int, None) + + exception = exc_info.value + + assert str(exception) == "Type Int! Cannot be None." + + +def test_serialize_value_with_nullable_type(): + + nullable_int = GraphQLInt + + assert serialize_value(nullable_int, None) is None From 2b3db148fbe02d4b04c3f00f571c1d27bb7feb23 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 12 Nov 2021 15:24:51 +0100 Subject: [PATCH 02/21] Skip datetime tests on Python 3.6 --- tests/custom_scalars/test_custom_scalar_datetime.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_custom_scalar_datetime.py index 22b3987a..058b269e 100644 --- a/tests/custom_scalars/test_custom_scalar_datetime.py +++ b/tests/custom_scalars/test_custom_scalar_datetime.py @@ -20,9 +20,6 @@ from gql import Client, gql -# Marking all tests in this file with the aiohttp marker -pytestmark = pytest.mark.aiohttp - def serialize_datetime(value: Any) -> str: if not isinstance(value, datetime): @@ -110,6 +107,9 @@ def resolve_seconds(root, _info, interval): schema = GraphQLSchema(query=queryType) +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) def test_shift_days(): client = Client(schema=schema) @@ -131,6 +131,9 @@ def test_shift_days(): assert result["shiftDays"] == "2021-11-17T11:58:13.461161" +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) def test_latest(): client = Client(schema=schema) @@ -153,6 +156,9 @@ def test_latest(): assert result["latest"] == in_five_days.isoformat() +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) def test_seconds(): client = Client(schema=schema) From c251a2f9206b4a29b4a8404c08574b7c5c7f2993 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 12 Nov 2021 19:49:14 +0100 Subject: [PATCH 03/21] Add update_schema_scalars function This function allows us to update a schema from a file or from introspection with custom scalars implementations --- gql/utilities/__init__.py | 5 + gql/utilities/update_schema_scalars.py | 32 ++++++ .../test_custom_scalar_money.py | 108 ++++++++++++++++++ 3 files changed, 145 insertions(+) create mode 100644 gql/utilities/__init__.py create mode 100644 gql/utilities/update_schema_scalars.py diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py new file mode 100644 index 00000000..68b80156 --- /dev/null +++ b/gql/utilities/__init__.py @@ -0,0 +1,5 @@ +from .update_schema_scalars import update_schema_scalars + +__all__ = [ + "update_schema_scalars", +] diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py new file mode 100644 index 00000000..d5434c6b --- /dev/null +++ b/gql/utilities/update_schema_scalars.py @@ -0,0 +1,32 @@ +from typing import Iterable, List + +from graphql import GraphQLError, GraphQLScalarType, GraphQLSchema + + +def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]): + """Update the scalars in a schema with the scalars provided. + + This can be used to update the default Custom Scalar implementation + when the schema has been provided from a text file or from introspection. + """ + + if not isinstance(scalars, Iterable): + raise GraphQLError("Scalars argument should be a list of scalars.") + + for scalar in scalars: + if not isinstance(scalar, GraphQLScalarType): + raise GraphQLError("Scalars should be instances of GraphQLScalarType.") + + try: + schema_scalar = schema.type_map[scalar.name] + except KeyError: + raise GraphQLError(f"Scalar '{scalar.name}' not found in schema.") + + assert isinstance(schema_scalar, GraphQLScalarType) + + # Update the conversion methods + # Using setattr because mypy has a false positive + # https://github.com/python/mypy/issues/2427 + setattr(schema_scalar, "serialize", scalar.serialize) + setattr(schema_scalar, "parse_value", scalar.parse_value) + setattr(schema_scalar, "parse_literal", scalar.parse_literal) diff --git a/tests/custom_scalars/test_custom_scalar_money.py b/tests/custom_scalars/test_custom_scalar_money.py index 6560a3eb..238308a9 100644 --- a/tests/custom_scalars/test_custom_scalar_money.py +++ b/tests/custom_scalars/test_custom_scalar_money.py @@ -19,6 +19,8 @@ from graphql.utilities import value_from_ast_untyped from gql import Client, gql +from gql.transport.exceptions import TransportQueryError +from gql.utilities import update_schema_scalars from gql.variable_values import serialize_value from ..conftest import MS @@ -471,6 +473,112 @@ async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): assert result["toEuros"] == 5 +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables_no_schema(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + with pytest.raises(TransportQueryError): + await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables_schema_from_introspection( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + schema = session.client.schema + + # Updating the Money Scalar in the schema + # We cannot replace it because some other objects keep a reference + # to the existing Scalar + # cannot do: schema.type_map["Money"] = MoneyScalar + + money_scalar = schema.type_map["Money"] + + money_scalar.serialize = MoneyScalar.serialize + money_scalar.parse_value = MoneyScalar.parse_value + money_scalar.parse_literal = MoneyScalar.parse_literal + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_update_schema_scalars(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + # Update the schema MoneyScalar default implementation from + # introspection with our provided conversion methods + update_schema_scalars(session.client.schema, [MoneyScalar]) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +def test_update_schema_scalars_invalid_scalar(): + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, [int]) + + exception = exc_info.value + + assert str(exception) == "Scalars should be instances of GraphQLScalarType." + + +def test_update_schema_scalars_invalid_scalar_argument(): + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, MoneyScalar) + + exception = exc_info.value + + assert str(exception) == "Scalars argument should be a list of scalars." + + +def test_update_schema_scalars_scalar_not_found_in_schema(): + + NotFoundScalar = GraphQLScalarType(name="abcd",) + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, [MoneyScalar, NotFoundScalar]) + + exception = exc_info.value + + assert str(exception) == "Scalar 'abcd' not found in schema." + + @pytest.mark.asyncio @pytest.mark.requests async def test_custom_scalar_serialize_variables_sync_transport( From 104788500e8d65e64ca7474d88304368348425ae Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 13 Nov 2021 13:56:47 +0100 Subject: [PATCH 04/21] Adding documentation --- docs/modules/gql.rst | 1 + docs/modules/utilities.rst | 6 + docs/usage/custom_scalars.rst | 134 ++++++++++++++++++ docs/usage/index.rst | 1 + .../test_custom_scalar_datetime.py | 42 +++++- 5 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 docs/modules/utilities.rst create mode 100644 docs/usage/custom_scalars.rst diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index aac47c86..06a89a96 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -21,3 +21,4 @@ Sub-Packages client transport dsl + utilities diff --git a/docs/modules/utilities.rst b/docs/modules/utilities.rst new file mode 100644 index 00000000..47043b98 --- /dev/null +++ b/docs/modules/utilities.rst @@ -0,0 +1,6 @@ +gql.utilities +============= + +.. currentmodule:: gql.utilities + +.. automodule:: gql.utilities diff --git a/docs/usage/custom_scalars.rst b/docs/usage/custom_scalars.rst new file mode 100644 index 00000000..baee441e --- /dev/null +++ b/docs/usage/custom_scalars.rst @@ -0,0 +1,134 @@ +Custom Scalars +============== + +Scalar types represent primitive values at the leaves of a query. + +GraphQL provides a number of built-in scalars (Int, Float, String, Boolean and ID), but a GraphQL backend +can add additional custom scalars to its schema to better express values in their data model. + +For example, a schema can define the Datetime scalar to represent an ISO-8601 encoded date. + +The schema will then only contain: + +.. code-block:: python + + scalar Datetime + +When custom scalars are sent to the backend (as inputs) or from the backend (as outputs), +their values need to be serialized to be composed +of only built-in scalars, then at the destination the serialized values will be parsed again to +be able to represent the scalar in its local internal representation. + +Because this serialization/unserialization is dependent on the language used at both sides, it is not +described in the schema and needs to be defined independently at both sides (client, backend). + +A custom scalar value can have two different representations during its transport: + + - as a serialized value (usually as json): + + * in the results sent by the backend + * in the variables sent by the client alongside the query + + - as "literal" inside the query itself sent by the client + +To define a custom scalar, you need 3 methods: + + - a :code:`serialize` method used: + + * by the backend to serialize a custom scalar output in the result + * by the client to serialize a custom scalar input in the variables + + - a :code:`parse_value` method used: + + * by the backend to unserialize custom scalars inputs in the variables sent by the client + * by the client to unserialize custom scalars outputs from the results + + - a :code:`parse_literal` method used: + + * by the backend to unserialize custom scalars inputs inside the query itself + +To define a custom scalar object, we define a :code:`GraphQLScalarType` from graphql-core with +its name and the implementation of the above methods. + +Example for Datetime: + +.. code-block:: python + + from datetime import datetime + from typing import Any, Dict, Optional + + from graphql import GraphQLScalarType, ValueNode + from graphql.utilities import value_from_ast_untyped + + + def serialize_datetime(value: Any) -> str: + return value.isoformat() + + + def parse_datetime_value(value: Any) -> datetime: + return datetime.fromisoformat(value) + + + def parse_datetime_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None + ) -> datetime: + ast_value = value_from_ast_untyped(value_node, variables) + return parse_datetime_value(ast_value) + + + DatetimeScalar = GraphQLScalarType( + name="Datetime", + serialize=serialize_datetime, + parse_value=parse_datetime_value, + parse_literal=parse_datetime_literal, + ) + +Custom Scalars in inputs +------------------------ + +To provide custom scalars in input with gql, you can: + +- serialize the scalar yourself as "literal" in the query: + +.. code-block:: python + + query = gql( + """{ + shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) + }""" + ) + +- serialize the scalar yourself in a variable: + +.. code-block:: python + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": "2021-11-12T11:58:13.461161", + } + + result = client.execute(query, variable_values=variable_values) + +- add a custom scalar to the schema with :func:`update_schema_scalars ` + and execute the query with :code:`serialize_variables=True` + and gql will serialize the variable values from a Python object representation. + +For this, you need to provide a schema or set :code:`fetch_schema_from_transport=True` +in the client to request the schema from the backend. + +.. code-block:: python + + from gql.utilities import update_schema_scalars + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + update_schema_scalars(session.client.schema, [DatetimeScalar]) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = {"time": datetime.now()} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) diff --git a/docs/usage/index.rst b/docs/usage/index.rst index a7dd4d56..4a38093a 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -10,3 +10,4 @@ Usage variables headers file_upload + custom_scalars diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_custom_scalar_datetime.py index 058b269e..25c6bb31 100644 --- a/tests/custom_scalars/test_custom_scalar_datetime.py +++ b/tests/custom_scalars/test_custom_scalar_datetime.py @@ -38,7 +38,7 @@ def parse_datetime_value(value: Any) -> datetime: raise GraphQLError("Cannot parse datetime value : " + inspect(value)) else: - raise GraphQLError("Cannot parseee datetime value: " + inspect(value)) + raise GraphQLError("Cannot parse datetime value: " + inspect(value)) def parse_datetime_literal( @@ -131,6 +131,46 @@ def test_shift_days(): assert result["shiftDays"] == "2021-11-17T11:58:13.461161" +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_shift_days_serialized_manually_in_query(): + + client = Client(schema=schema) + + query = gql( + """{ + shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) + }""" + ) + + result = client.execute(query) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_shift_days_serialized_manually_in_variables(): + + client = Client(schema=schema) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": "2021-11-12T11:58:13.461161", + } + + result = client.execute(query, variable_values=variable_values) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + @pytest.mark.skipif( not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" ) From f484ae224d161240d3f689f3a5e055374492cbad Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 15 Nov 2021 17:33:59 +0100 Subject: [PATCH 05/21] dsl_serialize_to_literals --- gql/dsl.py | 102 +++++++- .../custom_scalars/test_custom_scalar_json.py | 241 ++++++++++++++++++ tests/starwars/test_dsl.py | 33 ++- 3 files changed, 368 insertions(+), 8 deletions(-) create mode 100644 tests/custom_scalars/test_custom_scalar_json.py diff --git a/gql/dsl.py b/gql/dsl.py index f3bd1fe2..1646d402 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,15 +1,22 @@ import logging +import re from abc import ABC, abstractmethod +from math import isfinite from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast from graphql import ( ArgumentNode, + BooleanValueNode, DocumentNode, + EnumValueNode, FieldNode, + FloatValueNode, FragmentDefinitionNode, FragmentSpreadNode, GraphQLArgument, + GraphQLError, GraphQLField, + GraphQLID, GraphQLInputObjectType, GraphQLInputType, GraphQLInterfaceType, @@ -20,6 +27,7 @@ GraphQLSchema, GraphQLWrappingType, InlineFragmentNode, + IntValueNode, ListTypeNode, ListValueNode, NamedTypeNode, @@ -31,25 +39,76 @@ OperationDefinitionNode, OperationType, SelectionSetNode, + StringValueNode, TypeNode, Undefined, ValueNode, VariableDefinitionNode, VariableNode, assert_named_type, + is_enum_type, is_input_object_type, + is_leaf_type, is_list_type, is_non_null_type, is_wrapping_type, print_ast, ) -from graphql.pyutils import FrozenList -from graphql.utilities import ast_from_value as default_ast_from_value +from graphql.pyutils import FrozenList, inspect from .utils import to_camel_case log = logging.getLogger(__name__) +_re_integer_string = re.compile("^-?(?:0|[1-9][0-9]*)$") + + +def ast_from_serialized_value_untyped(serialized: Any) -> Optional[ValueNode]: + """Given a serialized value, try our best to produce an AST. + + Anything ressembling an array (instance of Mapping) will be converted + to an ObjectFieldNode. + + Anything ressembling a list (instance of Iterable - except str) + will be converted to a ListNode. + + In some cases, a custom scalar can be serialized differently in the query + than in the variables. In that case, this function will not work.""" + + if serialized is None or serialized is Undefined: + return NullValueNode() + + if isinstance(serialized, Mapping): + field_items = ( + (key, ast_from_serialized_value_untyped(value)) + for key, value in serialized.items() + ) + field_nodes = ( + ObjectFieldNode(name=NameNode(value=field_name), value=field_value) + for field_name, field_value in field_items + if field_value + ) + return ObjectValueNode(fields=FrozenList(field_nodes)) + + if isinstance(serialized, Iterable) and not isinstance(serialized, str): + maybe_nodes = (ast_from_serialized_value_untyped(item) for item in serialized) + nodes = filter(None, maybe_nodes) + return ListValueNode(values=FrozenList(nodes)) + + if isinstance(serialized, bool): + return BooleanValueNode(value=serialized) + + if isinstance(serialized, int): + return IntValueNode(value=f"{serialized:d}") + + if isinstance(serialized, float) and isfinite(serialized): + return FloatValueNode(value=f"{serialized:g}") + + if isinstance(serialized, str): + return StringValueNode(value=serialized) + + raise TypeError(f"Cannot convert value to AST: {inspect(serialized)}.") + def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: """ @@ -60,15 +119,21 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: VariableNode when value is a DSLVariable Produce a GraphQL Value AST given a Python object. + + Raises a GraphQLError instead of returning None if we receive an Undefined + of if we receive a Null value for a Non-Null type. """ if isinstance(value, DSLVariable): return value.set_type(type_).ast_variable if is_non_null_type(type_): type_ = cast(GraphQLNonNull, type_) - ast_value = ast_from_value(value, type_.of_type) + inner_type = type_.of_type + ast_value = ast_from_value(value, inner_type) if isinstance(ast_value, NullValueNode): - return None + raise GraphQLError( + "Received Null value for a Non-Null type " f"{inspect(inner_type)}." + ) return ast_value # only explicit None, not Undefined or NaN @@ -77,7 +142,7 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: # undefined if value is Undefined: - return None + raise GraphQLError(f"Received Undefined value for type {inspect(type_)}.") # Convert Python list to GraphQL list. If the GraphQLType is a list, but the value # is not a list, convert the value using the list's item type. @@ -108,7 +173,32 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: ) return ObjectValueNode(fields=FrozenList(field_nodes)) - return default_ast_from_value(value, type_) + if is_leaf_type(type_): + # Since value is an internally represented value, it must be serialized to an + # externally represented value before converting into an AST. + serialized = type_.serialize(value) # type: ignore + + # if the serialized value is a string, then we should use the + # type to determine if it is an enum, an ID or a normal string + if isinstance(serialized, str): + # Enum types use Enum literals. + if is_enum_type(type_): + return EnumValueNode(value=serialized) + + # ID types can use Int literals. + if type_ is GraphQLID and _re_integer_string.match(serialized): + return IntValueNode(value=serialized) + + return StringValueNode(value=serialized) + + # Some custom scalars will serialize to dicts or lists + # Providing here a default conversion to AST using our best judgment + # until graphql-js issue #1817 is solved + # https://github.com/graphql/graphql-js/issues/1817 + return ast_from_serialized_value_untyped(serialized) + + # Not reachable. All possible input types have been considered. + raise TypeError(f"Unexpected input type: {inspect(type_)}.") def dsl_gql( diff --git a/tests/custom_scalars/test_custom_scalar_json.py b/tests/custom_scalars/test_custom_scalar_json.py new file mode 100644 index 00000000..80f99850 --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_json.py @@ -0,0 +1,241 @@ +from typing import Any, Dict, Optional + +import pytest +from graphql import ( + GraphQLArgument, + GraphQLError, + GraphQLField, + GraphQLFloat, + GraphQLInt, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.language import ValueNode +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql +from gql.dsl import DSLSchema + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +def serialize_json(value: Any) -> Dict[str, Any]: + return value + + +def parse_json_value(value: Any) -> Any: + return value + + +def parse_json_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> Any: + return value_from_ast_untyped(value_node, variables) + + +JsonScalar = GraphQLScalarType( + name="JSON", + serialize=serialize_json, + parse_value=parse_json_value, + parse_literal=parse_json_literal, +) + +root_value = { + "players": [ + { + "name": "John", + "level": 3, + "is_connected": True, + "score": 123.45, + "friends": ["Alex", "Alicia"], + }, + { + "name": "Alex", + "level": 4, + "is_connected": False, + "score": 1337.69, + "friends": None, + }, + ] +} + + +def resolve_players(root, _info): + return root["players"] + + +queryType = GraphQLObjectType( + name="Query", fields={"players": GraphQLField(JsonScalar, resolve=resolve_players)}, +) + + +def resolve_add_player(root, _info, player): + print(f"player = {player!r}") + root["players"].append(player) + return {"players": root["players"]} + + +mutationType = GraphQLObjectType( + name="Mutation", + fields={ + "addPlayer": GraphQLField( + JsonScalar, + args={"player": GraphQLArgument(GraphQLNonNull(JsonScalar))}, + resolve=resolve_add_player, + ) + }, +) + +schema = GraphQLSchema(query=queryType, mutation=mutationType) + + +def test_json_value_output(): + + client = Client(schema=schema) + + query = gql("query {players}") + + result = client.execute(query, root_value=root_value) + + print(result) + + assert result["players"] == serialize_json(root_value["players"]) + + +def test_json_value_input_in_ast(): + + client = Client(schema=schema) + + query = gql( + """ + mutation adding_player { + addPlayer(player: { + name: "Tom", + level: 1, + is_connected: True, + score: 0, + friends: [ + "John" + ] + }) +}""" + ) + + result = client.execute(query, root_value=root_value) + + print(result) + + players = result["addPlayer"]["players"] + + assert players == serialize_json(root_value["players"]) + assert players[-1]["name"] == "Tom" + + +def test_json_value_input_in_ast_with_variables(): + + print(f"{schema.type_map!r}") + client = Client(schema=schema) + + # Note: we need to manually add the built-in types which + # are not present in the schema + schema.type_map["Int"] = GraphQLInt + schema.type_map["Float"] = GraphQLFloat + + query = gql( + """ + mutation adding_player( + $name: String!, + $level: Int!, + $is_connected: Boolean, + $score: Float!, + $friends: [String!]!) { + + addPlayer(player: { + name: $name, + level: $level, + is_connected: $is_connected, + score: $score, + friends: $friends, + }) +}""" + ) + + variable_values = { + "name": "Barbara", + "level": 1, + "is_connected": False, + "score": 69, + "friends": ["Alex", "John"], + } + + result = client.execute( + query, variable_values=variable_values, root_value=root_value + ) + + print(result) + + players = result["addPlayer"]["players"] + + assert players == serialize_json(root_value["players"]) + assert players[-1]["name"] == "Barbara" + + +def test_json_value_input_in_dsl_argument(): + + ds = DSLSchema(schema) + + new_player = { + "name": "Tim", + "level": 0, + "is_connected": False, + "score": 5, + "friends": ["Lea"], + } + + query = ds.Mutation.addPlayer(player=new_player) + + print(str(query)) + + assert ( + str(query) + == """addPlayer( + player: {name: "Tim", level: 0, is_connected: false, score: 5, friends: ["Lea"]} +)""" + ) + + +def test_none_json_value_input_in_dsl_argument(): + + ds = DSLSchema(schema) + + with pytest.raises(GraphQLError) as exc_info: + ds.Mutation.addPlayer(player=None) + + assert "Received Null value for a Non-Null type JSON." in str(exc_info.value) + + +def test_json_value_input_with_none_list_in_dsl_argument(): + + ds = DSLSchema(schema) + + new_player = { + "name": "Bob", + "level": 9001, + "is_connected": True, + "score": 666.66, + "friends": None, + } + + query = ds.Mutation.addPlayer(player=new_player) + + print(str(query)) + + assert ( + str(query) + == """addPlayer( + player: {name: "Bob", level: 9001, is_connected: true, score: 666.66, friends: null} +)""" + ) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 93de6c03..d18bb37d 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -1,5 +1,7 @@ import pytest from graphql import ( + GraphQLError, + GraphQLID, GraphQLInt, GraphQLList, GraphQLNonNull, @@ -23,6 +25,7 @@ DSLSubscription, DSLVariable, DSLVariableDefinitions, + ast_from_serialized_value_untyped, ast_from_value, dsl_gql, ) @@ -54,12 +57,38 @@ def test_ast_from_value_with_none(): def test_ast_from_value_with_undefined(): - assert ast_from_value(Undefined, GraphQLInt) is None + with pytest.raises(GraphQLError) as exc_info: + ast_from_value(Undefined, GraphQLInt) + + assert "Received Undefined value for type Int." in str(exc_info.value) + + +def test_ast_from_value_with_graphqlid(): + + assert ast_from_value("12345", GraphQLID) == IntValueNode(value="12345") + + +def test_ast_from_value_with_invalid_type(): + with pytest.raises(TypeError) as exc_info: + ast_from_value(4, None) + + assert "Unexpected input type: None." in str(exc_info.value) def test_ast_from_value_with_non_null_type_and_none(): typ = GraphQLNonNull(GraphQLInt) - assert ast_from_value(None, typ) is None + + with pytest.raises(GraphQLError) as exc_info: + ast_from_value(None, typ) + + assert "Received Null value for a Non-Null type Int." in str(exc_info.value) + + +def test_ast_from_serialized_value_untyped_typeerror(): + with pytest.raises(TypeError) as exc_info: + ast_from_serialized_value_untyped(GraphQLInt) + + assert "Cannot convert value to AST: Int." in str(exc_info.value) def test_variable_to_ast_type_passing_wrapping_type(): From 24ee0dd80d85f9aba596784c7b8ed5e977eb7d93 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 19 Nov 2021 18:13:25 +0100 Subject: [PATCH 06/21] Add parse_result feature --- gql/client.py | 55 +++- gql/utilities/__init__.py | 2 + gql/utilities/parse_result.py | 259 ++++++++++++++++++ .../test_custom_scalar_datetime.py | 16 +- .../test_custom_scalar_money.py | 12 +- tests/starwars/test_parse_results.py | 191 +++++++++++++ tests/starwars/test_subscription.py | 4 +- tests/test_async_client_validation.py | 2 +- 8 files changed, 524 insertions(+), 17 deletions(-) create mode 100644 gql/utilities/parse_result.py create mode 100644 tests/starwars/test_parse_results.py diff --git a/gql/client.py b/gql/client.py index 368193cc..32e62b6d 100644 --- a/gql/client.py +++ b/gql/client.py @@ -17,6 +17,7 @@ from .transport.exceptions import TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport +from .utilities import parse_result as parse_result_fn from .variable_values import serialize_variable_values @@ -48,6 +49,7 @@ def __init__( transport: Optional[Union[Transport, AsyncTransport]] = None, fetch_schema_from_transport: bool = False, execute_timeout: Optional[Union[int, float]] = 10, + parse_results: bool = False, ): """Initialize the client with the given parameters. @@ -59,6 +61,8 @@ def __init__( :param execute_timeout: The maximum time in seconds for the execution of a request before a TimeoutError is raised. Only used for async transports. Passing None results in waiting forever for a response. + :param parse_results: Whether gql will try to parse the serialized output + sent by the backend. Can be used to unserialize custom scalars or enums. """ assert not ( type_def and introspection @@ -108,6 +112,8 @@ def __init__( # Enforced timeout of the execute function (only for async transports) self.execute_timeout = execute_timeout + self.parse_results = parse_results + def validate(self, document: DocumentNode): """:meta private:""" assert ( @@ -297,6 +303,7 @@ def _execute( variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: bool = False, + parse_result: Optional[bool] = None, **kwargs, ) -> ExecutionResult: """Execute the provided document AST synchronously using @@ -307,6 +314,8 @@ def _execute( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. The extra arguments are passed to the transport execute method.""" @@ -323,7 +332,7 @@ def _execute( operation_name=operation_name, ) - return self.transport.execute( + result = self.transport.execute( document, *args, variable_values=variable_values, @@ -331,6 +340,13 @@ def _execute( **kwargs, ) + # Unserialize the result if requested + if self.client.schema: + if parse_result or (parse_result is None and self.client.parse_results): + result.data = parse_result_fn(self.client.schema, document, result.data) + + return result + def execute( self, document: DocumentNode, @@ -338,6 +354,7 @@ def execute( variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: bool = False, + parse_result: Optional[bool] = None, **kwargs, ) -> Dict: """Execute the provided document AST synchronously using @@ -351,6 +368,8 @@ def execute( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. The extra arguments are passed to the transport execute method.""" @@ -361,6 +380,7 @@ def execute( variable_values=variable_values, operation_name=operation_name, serialize_variables=serialize_variables, + parse_result=parse_result, **kwargs, ) @@ -409,6 +429,7 @@ async def _subscribe( variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: bool = False, + parse_result: Optional[bool] = None, **kwargs, ) -> AsyncGenerator[ExecutionResult, None]: """Coroutine to subscribe asynchronously to the provided document AST @@ -423,6 +444,8 @@ async def _subscribe( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. The extra arguments are passed to the transport subscribe method.""" @@ -456,7 +479,17 @@ async def _subscribe( try: async for result in inner_generator: + + if self.client.schema: + if parse_result or ( + parse_result is None and self.client.parse_results + ): + result.data = parse_result_fn( + self.client.schema, document, result.data + ) + yield result + finally: await inner_generator.aclose() @@ -467,6 +500,7 @@ async def subscribe( variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: bool = False, + parse_result: Optional[bool] = None, **kwargs, ) -> AsyncGenerator[Dict, None]: """Coroutine to subscribe asynchronously to the provided document AST @@ -480,6 +514,8 @@ async def subscribe( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. The extra arguments are passed to the transport subscribe method.""" @@ -489,6 +525,7 @@ async def subscribe( variable_values=variable_values, operation_name=operation_name, serialize_variables=serialize_variables, + parse_result=parse_result, **kwargs, ) @@ -514,6 +551,7 @@ async def _execute( variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: bool = False, + parse_result: Optional[bool] = None, **kwargs, ) -> ExecutionResult: """Coroutine to execute the provided document AST asynchronously using @@ -527,6 +565,8 @@ async def _execute( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. The extra arguments are passed to the transport execute method.""" @@ -544,7 +584,7 @@ async def _execute( ) # Execute the query with the transport with a timeout - return await asyncio.wait_for( + result = await asyncio.wait_for( self.transport.execute( document, variable_values=variable_values, @@ -555,6 +595,13 @@ async def _execute( self.client.execute_timeout, ) + # Unserialize the result if requested + if self.client.schema: + if parse_result or (parse_result is None and self.client.parse_results): + result.data = parse_result_fn(self.client.schema, document, result.data) + + return result + async def execute( self, document: DocumentNode, @@ -562,6 +609,7 @@ async def execute( variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: bool = False, + parse_result: Optional[bool] = None, **kwargs, ) -> Dict: """Coroutine to execute the provided document AST asynchronously using @@ -575,6 +623,8 @@ async def execute( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. The extra arguments are passed to the transport execute method.""" @@ -585,6 +635,7 @@ async def execute( variable_values=variable_values, operation_name=operation_name, serialize_variables=serialize_variables, + parse_result=parse_result, **kwargs, ) diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py index 68b80156..b7ab80e7 100644 --- a/gql/utilities/__init__.py +++ b/gql/utilities/__init__.py @@ -1,5 +1,7 @@ +from .parse_result import parse_result from .update_schema_scalars import update_schema_scalars __all__ = [ "update_schema_scalars", + "parse_result", ] diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py new file mode 100644 index 00000000..e390c76e --- /dev/null +++ b/gql/utilities/parse_result.py @@ -0,0 +1,259 @@ +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast + +from graphql import ( + IDLE, + REMOVE, + DocumentNode, + FieldNode, + FragmentDefinitionNode, + FragmentSpreadNode, + GraphQLEnumType, + GraphQLError, + GraphQLScalarType, + GraphQLSchema, + InlineFragmentNode, + NameNode, + OperationDefinitionNode, + SelectionSetNode, + TypeInfo, + TypeInfoVisitor, + Visitor, + get_nullable_type, + is_nullable_type, + visit, +) +from graphql.language.visitor import VisitorActionEnum + +# Equivalent to QUERY_DOCUMENT_KEYS but only for fields interesting to +# visit to parse the results +RESULT_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = { + "name": (), + "document": ("definitions",), + "operation_definition": ("name", "selection_set",), + "selection_set": ("selections",), + "field": ("alias", "name", "selection_set"), + "fragment_spread": ("name",), + "inline_fragment": ("type_condition", "selection_set"), + "fragment_definition": ("name", "type_condition", "selection_set",), +} + + +class ParseResultVisitor(Visitor): + def __init__( + self, + schema: GraphQLSchema, + document: DocumentNode, + type_info: TypeInfo, + result: Dict[str, Any], + visit_fragment: bool = False, + ): + + self.schema: GraphQLSchema = schema + self.document: DocumentNode = document + self.type_info: TypeInfo = type_info + self.result: Dict[str, Any] = result + self.visit_fragment: bool = visit_fragment + + self.result_stack: List[Any] = [] + + @property + def current_result(self): + try: + return self.result_stack[-1] + except IndexError: + return self.result + + def leave_name(self, node: NameNode, *_args: Any,) -> str: + return node.value + + @staticmethod + def leave_document(node: DocumentNode, *_args: Any) -> Dict[str, Any]: + results = cast(List[Dict[str, Any]], node.definitions) + return {k: v for result in results for k, v in result.items()} + + @staticmethod + def leave_operation_definition( + node: OperationDefinitionNode, *_args: Any + ) -> Dict[str, Any]: + selections = cast(List[Dict[str, Any]], node.selection_set) + return {k: v for s in selections for k, v in s.items()} + + @staticmethod + def leave_selection_set(node: SelectionSetNode, *_args: Any) -> Dict[str, Any]: + partial_results = cast(Dict[str, Any], node.selections) + return partial_results + + def enter_field( + self, node: FieldNode, *_args: Any, + ) -> Union[None, VisitorActionEnum, Dict[str, Any]]: + + name = node.alias.value if node.alias else node.name.value + + if isinstance(self.current_result, Mapping): + if name in self.current_result: + + result_value = self.current_result[name] + + # If the result for this field is a list, then we need + # to recursively visit the same node multiple times for each + # item in the list. + if ( + not isinstance(result_value, Mapping) + and isinstance(result_value, Iterable) + and not isinstance(result_value, str) + ): + visits: List[Dict[str, Any]] = [] + + for item in result_value: + + new_result = {name: item} + + inner_visit = cast( + Dict[str, Any], + visit( + node, + ParseResultVisitor( + self.schema, + self.document, + self.type_info, + new_result, + ), + visitor_keys=RESULT_DOCUMENT_KEYS, + ), + ) + + visits.append(inner_visit[name]) + + return {name: visits} + + # If the result for this field is not a list, then add it + # to the result stack so that it becomes the current_value + # for the next inner fields + self.result_stack.append(result_value) + + return IDLE + + else: + # Key not found in result. + # Should never happen in theory with a correct GraphQL backend + # Silently ignoring this field + return REMOVE + + elif self.current_result is None: + # Result was null for this field -> remove + return REMOVE + + raise GraphQLError( + f"Invalid result for container of field {name}: {self.current_result!r}" + ) + + def leave_field(self, node: FieldNode, *_args: Any,) -> Dict[str, Any]: + + name = cast(str, node.alias if node.alias else node.name) + + if self.current_result is None: + return {name: None} + elif node.selection_set is None: + + field_type = self.type_info.get_type() + if is_nullable_type(field_type): + field_type = get_nullable_type(field_type) # type: ignore + + if isinstance(field_type, (GraphQLScalarType, GraphQLEnumType)): + + parsed_value = field_type.parse_value(self.current_result) + else: + parsed_value = self.current_result + + return_value = {name: parsed_value} + else: + + partial_results = cast(List[Dict[str, Any]], node.selection_set) + + return_value = { + name: {k: v for pr in partial_results for k, v in pr.items()} + } + + # Go up a level in the result stack + self.result_stack.pop() + + return return_value + + # Fragments + + def enter_fragment_definition( + self, node: FragmentDefinitionNode, *_args: Any + ) -> Union[None, VisitorActionEnum]: + + if self.visit_fragment: + return IDLE + else: + return REMOVE + + @staticmethod + def leave_fragment_definition( + node: FragmentDefinitionNode, *_args: Any + ) -> Dict[str, Any]: + + selections = cast(List[Dict[str, Any]], node.selection_set) + return {k: v for s in selections for k, v in s.items()} + + def leave_fragment_spread( + self, node: FragmentSpreadNode, *_args: Any + ) -> Dict[str, Any]: + + fragment_name = node.name + + for definition in self.document.definitions: + if isinstance(definition, FragmentDefinitionNode): + if definition.name.value == fragment_name: + fragment_result = visit( + definition, + ParseResultVisitor( + self.schema, + self.document, + self.type_info, + self.current_result, + visit_fragment=True, + ), + visitor_keys=RESULT_DOCUMENT_KEYS, + ) + return fragment_result + + raise GraphQLError(f'Fragment "{fragment_name}" not found in schema!') + + @staticmethod + def leave_inline_fragment(node: InlineFragmentNode, *_args: Any) -> Dict[str, Any]: + + selections = cast(List[Dict[str, Any]], node.selection_set) + return {k: v for s in selections for k, v in s.items()} + + +def parse_result( + schema: GraphQLSchema, document: DocumentNode, result: Optional[Dict[str, Any]] +) -> Optional[Dict[str, Any]]: + """Unserialize a result received from a GraphQL backend. + + Given a schema, a query and a serialized result, + provide a new result with parsed values. + + If the result contains only built-in GraphQL scalars (String, Int, Float, ...) + then the parsed result should be unchanged. + + If the result contains custom scalars or enums, then those values + will be parsed with the parse_value method of the custom scalar or enum + definition in the schema.""" + + if result is None: + return None + + type_info = TypeInfo(schema) + + visited = visit( + document, + TypeInfoVisitor( + type_info, ParseResultVisitor(schema, document, type_info, result), + ), + ) + + return visited diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_custom_scalar_datetime.py index 25c6bb31..fe031dd3 100644 --- a/tests/custom_scalars/test_custom_scalar_datetime.py +++ b/tests/custom_scalars/test_custom_scalar_datetime.py @@ -112,7 +112,7 @@ def resolve_seconds(root, _info, interval): ) def test_shift_days(): - client = Client(schema=schema) + client = Client(schema=schema, parse_results=True) now = datetime.fromisoformat("2021-11-12T11:58:13.461161") @@ -128,7 +128,7 @@ def test_shift_days(): print(result) - assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + assert result["shiftDays"] == datetime.fromisoformat("2021-11-17T11:58:13.461161") @pytest.mark.skipif( @@ -144,11 +144,11 @@ def test_shift_days_serialized_manually_in_query(): }""" ) - result = client.execute(query) + result = client.execute(query, parse_result=True) print(result) - assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + assert result["shiftDays"] == datetime.fromisoformat("2021-11-17T11:58:13.461161") @pytest.mark.skipif( @@ -156,7 +156,7 @@ def test_shift_days_serialized_manually_in_query(): ) def test_shift_days_serialized_manually_in_variables(): - client = Client(schema=schema) + client = Client(schema=schema, parse_results=True) query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") @@ -168,7 +168,7 @@ def test_shift_days_serialized_manually_in_variables(): print(result) - assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + assert result["shiftDays"] == datetime.fromisoformat("2021-11-17T11:58:13.461161") @pytest.mark.skipif( @@ -176,7 +176,7 @@ def test_shift_days_serialized_manually_in_variables(): ) def test_latest(): - client = Client(schema=schema) + client = Client(schema=schema, parse_results=True) now = datetime.fromisoformat("2021-11-12T11:58:13.461161") in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") @@ -193,7 +193,7 @@ def test_latest(): print(result) - assert result["latest"] == in_five_days.isoformat() + assert result["latest"] == in_five_days @pytest.mark.skipif( diff --git a/tests/custom_scalars/test_custom_scalar_money.py b/tests/custom_scalars/test_custom_scalar_money.py index 238308a9..014938db 100644 --- a/tests/custom_scalars/test_custom_scalar_money.py +++ b/tests/custom_scalars/test_custom_scalar_money.py @@ -140,7 +140,7 @@ async def subscribe_spend_all(_root, _info, money): def test_custom_scalar_in_output(): - client = Client(schema=schema) + client = Client(schema=schema, parse_results=True) query = gql("{balance}") @@ -148,7 +148,7 @@ def test_custom_scalar_in_output(): print(result) - assert result["balance"] == serialize_money(root_value) + assert result["balance"] == root_value def test_custom_scalar_in_input_query(): @@ -301,16 +301,18 @@ def test_custom_scalar_subscribe_in_input_variable_values_serialized(): variable_values = {"money": money_value} - expected_result = {"spend": {"amount": 10, "currency": "DM"}} + expected_result = {"spend": Money(10, "DM")} for result in client.subscribe( query, variable_values=variable_values, root_value=root_value, serialize_variables=True, + parse_result=True, ): print(f"result = {result!r}") - expected_result["spend"]["amount"] = expected_result["spend"]["amount"] - 1 + assert isinstance(result["spend"], Money) + expected_result["spend"] = Money(expected_result["spend"].amount - 1, "DM") assert expected_result == result @@ -588,7 +590,7 @@ async def test_custom_scalar_serialize_variables_sync_transport( server, transport = await make_sync_money_transport(aiohttp_server) def test_code(): - with Client(schema=schema, transport=transport,) as session: + with Client(schema=schema, transport=transport, parse_results=True) as session: query = gql("query myquery($money: Money) {toEuros(money: $money)}") diff --git a/tests/starwars/test_parse_results.py b/tests/starwars/test_parse_results.py new file mode 100644 index 00000000..72b25177 --- /dev/null +++ b/tests/starwars/test_parse_results.py @@ -0,0 +1,191 @@ +import pytest +from graphql import GraphQLError + +from gql import gql +from gql.utilities import parse_result +from tests.starwars.schema import StarWarsSchema + + +def test_hero_name_and_friends_query(): + query = gql( + """ + query HeroNameAndFriendsQuery { + hero { + id + friends { + name + } + name + } + } + """ + ) + result = { + "hero": { + "id": "2001", + "friends": [ + {"name": "Luke Skywalker"}, + {"name": "Han Solo"}, + {"name": "Leia Organa"}, + ], + "name": "R2-D2", + } + } + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result + + +def test_key_not_found_in_result(): + + query = gql( + """ + { + hero { + id + } + } + """ + ) + + # Backend returned an invalid result without the hero key + # Should be impossible. In that case, we ignore the missing key + result = {} + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result + + +def test_invalid_result_raise_error(): + + query = gql( + """ + { + hero { + id + } + } + """ + ) + + result = {"hero": 5} + + with pytest.raises(GraphQLError) as exc_info: + + parse_result(StarWarsSchema, query, result) + + assert "Invalid result for container of field id: 5" in str(exc_info) + + +def test_fragment(): + + query = gql( + """ + query UseFragment { + luke: human(id: "1000") { + ...HumanFragment + } + leia: human(id: "1003") { + ...HumanFragment + } + } + fragment HumanFragment on Human { + name + homePlanet + } + """ + ) + + result = { + "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, + "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"}, + } + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result + + +def test_fragment_not_found(): + + query = gql( + """ + query UseFragment { + luke: human(id: "1000") { + ...HumanFragment + } + } + """ + ) + + result = { + "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, + } + + with pytest.raises(GraphQLError) as exc_info: + + parse_result(StarWarsSchema, query, result) + + assert 'Fragment "HumanFragment" not found in schema!' in str(exc_info) + + +def test_return_none_if_result_is_none(): + + query = gql( + """ + query { + hero { + id + } + } + """ + ) + + result = None + + assert parse_result(StarWarsSchema, query, result) is None + + +def test_null_result_is_allowed(): + + query = gql( + """ + query { + hero { + id + } + } + """ + ) + + result = {"hero": None} + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result + + +def test_inline_fragment(): + + query = gql( + """ + query UseFragment { + luke: human(id: "1000") { + ... on Human { + name + homePlanet + } + } + } + """ + ) + + result = { + "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, + } + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index 3753ab2f..2516701f 100644 --- a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -53,7 +53,9 @@ async def test_subscription_support_using_client(): async with Client(schema=StarWarsSchema) as session: results = [ result["reviewAdded"] - async for result in session.subscribe(subs, variable_values=params) + async for result in session.subscribe( + subs, variable_values=params, parse_result=False + ) ] assert results == expected diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index 1402aa59..107bd6c2 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -112,7 +112,7 @@ async def test_async_client_validation( expected = [] async for result in session.subscribe( - subscription, variable_values=variable_values + subscription, variable_values=variable_values, parse_result=False ): review = result["reviewAdded"] From c3a8f42d6311bf1a098d04878ab898b4fc929bb5 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 19 Nov 2021 18:21:17 +0100 Subject: [PATCH 07/21] don't visit name nodes --- gql/utilities/parse_result.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index e390c76e..f613d952 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -12,7 +12,6 @@ GraphQLScalarType, GraphQLSchema, InlineFragmentNode, - NameNode, OperationDefinitionNode, SelectionSetNode, TypeInfo, @@ -29,12 +28,12 @@ RESULT_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = { "name": (), "document": ("definitions",), - "operation_definition": ("name", "selection_set",), + "operation_definition": ("selection_set",), "selection_set": ("selections",), - "field": ("alias", "name", "selection_set"), - "fragment_spread": ("name",), - "inline_fragment": ("type_condition", "selection_set"), - "fragment_definition": ("name", "type_condition", "selection_set",), + "field": ("alias", "selection_set"), + "fragment_spread": (), + "inline_fragment": ("selection_set",), + "fragment_definition": ("selection_set",), } @@ -63,9 +62,6 @@ def current_result(self): except IndexError: return self.result - def leave_name(self, node: NameNode, *_args: Any,) -> str: - return node.value - @staticmethod def leave_document(node: DocumentNode, *_args: Any) -> Dict[str, Any]: results = cast(List[Dict[str, Any]], node.definitions) @@ -149,7 +145,7 @@ def enter_field( def leave_field(self, node: FieldNode, *_args: Any,) -> Dict[str, Any]: - name = cast(str, node.alias if node.alias else node.name) + name = cast(str, node.alias.value if node.alias else node.name.value) if self.current_result is None: return {name: None} @@ -202,7 +198,7 @@ def leave_fragment_spread( self, node: FragmentSpreadNode, *_args: Any ) -> Dict[str, Any]: - fragment_name = node.name + fragment_name = node.name.value for definition in self.document.definitions: if isinstance(definition, FragmentDefinitionNode): From 584dd36bb848c934390e0202a37be0604ab415aa Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 19 Nov 2021 19:13:52 +0100 Subject: [PATCH 08/21] Add docs --- docs/usage/custom_scalars.rst | 38 +++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/docs/usage/custom_scalars.rst b/docs/usage/custom_scalars.rst index baee441e..98e3236a 100644 --- a/docs/usage/custom_scalars.rst +++ b/docs/usage/custom_scalars.rst @@ -132,3 +132,41 @@ in the client to request the schema from the backend. result = await session.execute( query, variable_values=variable_values, serialize_variables=True ) + + # result["time"] is a string + +Custom Scalars in output +------------------------ + +By default, gql returns the serialized result from the backend without parsing +(except json unserialization to Python default types). + +if you want to convert the result of custom scalars to custom objects, +you can request gql to parse the results. + +- use :code:`Client(..., parse_results=True)` to request parsing for all queries +- use :code:`execute(..., parse_result=True)` or :code:`subscribe(..., parse_result=True)` if + you want gql to parse only the result of a single query. + +Same example as above, with result parsing enabled: + +.. code-block:: python + + from gql.utilities import update_schema_scalars + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + update_schema_scalars(session.client.schema, [DatetimeScalar]) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = {"time": datetime.now()} + + result = await session.execute( + query, + variable_values=variable_values, + serialize_variables=True, + parse_result=True, + ) + + # now result["time"] type is a datetime instead of string From 01bea1d5bdc6670c17e5ba3e8f46f7378507814d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 21 Nov 2021 23:45:56 +0100 Subject: [PATCH 09/21] fix list of custom scalars + add enum tests --- gql/utilities/parse_result.py | 343 +++++++++++++----- tests/conftest.py | 1 + .../test_custom_scalar_money.py | 98 ++++- tests/custom_scalars/test_enum_colors.py | 202 +++++++++++ tests/starwars/test_parse_results.py | 2 +- 5 files changed, 549 insertions(+), 97 deletions(-) create mode 100644 tests/custom_scalars/test_enum_colors.py diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index f613d952..843632f0 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast from graphql import ( @@ -9,49 +10,93 @@ FragmentSpreadNode, GraphQLEnumType, GraphQLError, + GraphQLInterfaceType, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, GraphQLScalarType, GraphQLSchema, + GraphQLType, InlineFragmentNode, + Node, OperationDefinitionNode, SelectionSetNode, TypeInfo, TypeInfoVisitor, Visitor, - get_nullable_type, - is_nullable_type, + print_ast, visit, ) from graphql.language.visitor import VisitorActionEnum +from graphql.pyutils import inspect + +log = logging.getLogger(__name__) # Equivalent to QUERY_DOCUMENT_KEYS but only for fields interesting to # visit to parse the results RESULT_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = { - "name": (), "document": ("definitions",), "operation_definition": ("selection_set",), "selection_set": ("selections",), - "field": ("alias", "selection_set"), - "fragment_spread": (), + "field": ("selection_set",), "inline_fragment": ("selection_set",), "fragment_definition": ("selection_set",), } +def _ignore_non_null(type_: GraphQLType): + """Removes the GraphQLNonNull wrappings around types.""" + if isinstance(type_, GraphQLNonNull): + return type_.of_type + else: + return type_ + + +def _get_fragment(document, fragment_name): + """Returns a fragment from the document.""" + for definition in document.definitions: + if isinstance(definition, FragmentDefinitionNode): + if definition.name.value == fragment_name: + return definition + + raise GraphQLError(f'Fragment "{fragment_name}" not found in document!') + + class ParseResultVisitor(Visitor): def __init__( self, schema: GraphQLSchema, document: DocumentNode, - type_info: TypeInfo, + node: Node, result: Dict[str, Any], + type_info: TypeInfo, visit_fragment: bool = False, + inside_list_level: int = 0, ): + """Recursive Implementation of a Visitor class to parse results + correspondind to a schema and a document. + + Using a TypeInfo class to get the node types during traversal. + If we reach a list in the results, then we parse each + item of the list recursively, traversing the same nodes + of the query again. + + During traversal, we keep the current position in the result + in the result_stack field. + + Alongside the field type, we calculate the "result type" + which is computed from the field type and the current + recursive level we are for this field + (:code:`inside_list_level` argument). + """ self.schema: GraphQLSchema = schema self.document: DocumentNode = document - self.type_info: TypeInfo = type_info + self.node: Node = node self.result: Dict[str, Any] = result + self.type_info: TypeInfo = type_info self.visit_fragment: bool = visit_fragment + self.inside_list_level = inside_list_level self.result_stack: List[Any] = [] @@ -79,87 +124,174 @@ def leave_selection_set(node: SelectionSetNode, *_args: Any) -> Dict[str, Any]: partial_results = cast(Dict[str, Any], node.selections) return partial_results + @staticmethod + def in_first_field(path): + return path.count("selections") <= 1 + + def get_current_result_type(self, path): + field_type = self.type_info.get_type() + + list_level = self.inside_list_level + + result_type = _ignore_non_null(field_type) + + if self.in_first_field(path): + + while list_level > 0: + assert isinstance(result_type, GraphQLList) + result_type = _ignore_non_null(result_type.of_type) + + list_level -= 1 + + return result_type + def enter_field( - self, node: FieldNode, *_args: Any, + self, + node: FieldNode, + key: str, + parent: Node, + path: List[Node], + ancestors: List[Node], ) -> Union[None, VisitorActionEnum, Dict[str, Any]]: name = node.alias.value if node.alias else node.name.value - if isinstance(self.current_result, Mapping): - if name in self.current_result: + if log.isEnabledFor(logging.DEBUG): + log.debug(f"Enter field {name}") + log.debug(f" path={path!r}") + log.debug(f" current_result={self.current_result!r}") - result_value = self.current_result[name] + if self.current_result is None: + # Result was null for this field -> remove + return REMOVE + + elif isinstance(self.current_result, Mapping): - # If the result for this field is a list, then we need - # to recursively visit the same node multiple times for each - # item in the list. - if ( - not isinstance(result_value, Mapping) - and isinstance(result_value, Iterable) - and not isinstance(result_value, str) - ): - visits: List[Dict[str, Any]] = [] - - for item in result_value: - - new_result = {name: item} - - inner_visit = cast( - Dict[str, Any], - visit( - node, - ParseResultVisitor( - self.schema, - self.document, - self.type_info, - new_result, - ), - visitor_keys=RESULT_DOCUMENT_KEYS, - ), - ) - - visits.append(inner_visit[name]) - - return {name: visits} - - # If the result for this field is not a list, then add it - # to the result stack so that it becomes the current_value - # for the next inner fields - self.result_stack.append(result_value) - - return IDLE - - else: + try: + result_value = self.current_result[name] + except KeyError: # Key not found in result. # Should never happen in theory with a correct GraphQL backend # Silently ignoring this field + log.debug(f"Key {name} not found in result --> REMOVE") return REMOVE - elif self.current_result is None: - # Result was null for this field -> remove - return REMOVE + log.debug(f" result_value={result_value}") + + # If the result for this field is a list, then we need + # to recursively visit the same node multiple times for each + # item in the list. + if ( + not isinstance(result_value, Mapping) + and isinstance(result_value, Iterable) + and not isinstance(result_value, str) + ): + + # We get the field_type from type_info + field_type = self.type_info.get_type() + + # We calculate a virtual "result type" depending on our recursion level. + result_type = self.get_current_result_type(path) + + if not isinstance(result_type, GraphQLList): + raise TypeError( + f"Received iterable result for a non-list type: {result_value}" + ) + + # Finding out the inner type of the list + inner_type = _ignore_non_null(result_type.of_type) + + if log.isEnabledFor(logging.DEBUG): + log.debug(" List detected:") + log.debug(f" field_type={inspect(field_type)}") + log.debug(f" result_type={inspect(result_type)}") + log.debug(f" inner_type={inspect(inner_type)}\n") + + visits: List[Dict[str, Any]] = [] + + # Get parent type + initial_type = self.type_info.get_parent_type() + assert isinstance( + initial_type, (GraphQLObjectType, GraphQLInterfaceType) + ) + + # Get parent SelectionSet node + new_node = ancestors[-1] + assert isinstance(new_node, SelectionSetNode) + + for item in result_value: + + new_result = {name: item} + + if log.isEnabledFor(logging.DEBUG): + log.debug(f" recursive new_result={new_result}") + log.debug(f" recursive ast={print_ast(node)}") + log.debug(f" recursive path={path!r}") + log.debug(f" recursive initial_type={initial_type!r}\n") + + # inside_list_level = (self.inside_list_level + 1) + # if self.in_first_field(path) else 1 + inside_list_level = self.inside_list_level + 1 + + inner_visit = parse_result_recursive( + self.schema, + self.document, + new_node, + new_result, + initial_type=initial_type, + inside_list_level=inside_list_level, + ) + log.debug(f" recursive result={inner_visit}\n") + + inner_visit = cast(List[Dict[str, Any]], inner_visit) + visits.append(inner_visit[0][name]) + + result_value = {name: visits} + log.debug(f" recursive visits final result = {result_value}\n") + return result_value + + # If the result for this field is not a list, then add it + # to the result stack so that it becomes the current_value + # for the next inner fields + self.result_stack.append(result_value) + + return IDLE raise GraphQLError( f"Invalid result for container of field {name}: {self.current_result!r}" ) - def leave_field(self, node: FieldNode, *_args: Any,) -> Dict[str, Any]: + def leave_field( + self, + node: FieldNode, + key: str, + parent: Node, + path: List[Node], + ancestors: List[Node], + ) -> Dict[str, Any]: name = cast(str, node.alias.value if node.alias else node.name.value) + log.debug(f"Leave field {name}") + if self.current_result is None: + + log.debug(f"Leave field {name}: returning None") return {name: None} + elif node.selection_set is None: field_type = self.type_info.get_type() - if is_nullable_type(field_type): - field_type = get_nullable_type(field_type) # type: ignore + result_type = self.get_current_result_type(path) + + if log.isEnabledFor(logging.DEBUG): + log.debug(f" field type of {name} is {inspect(field_type)}") + log.debug(f" result type of {name} is {inspect(result_type)}") - if isinstance(field_type, (GraphQLScalarType, GraphQLEnumType)): + assert isinstance(result_type, (GraphQLScalarType, GraphQLEnumType)) - parsed_value = field_type.parse_value(self.current_result) - else: - parsed_value = self.current_result + # Finally parsing a single scalar using the parse_value method + parsed_value = result_type.parse_value(self.current_result) return_value = {name: parsed_value} else: @@ -173,6 +305,8 @@ def leave_field(self, node: FieldNode, *_args: Any,) -> Dict[str, Any]: # Go up a level in the result stack self.result_stack.pop() + log.debug(f"Leave field {name}: returning {return_value}") + return return_value # Fragments @@ -181,6 +315,10 @@ def enter_fragment_definition( self, node: FragmentDefinitionNode, *_args: Any ) -> Union[None, VisitorActionEnum]: + if log.isEnabledFor(logging.DEBUG): + log.debug(f"Enter fragment definition {node.name.value}.") + log.debug(f"visit_fragment={self.visit_fragment!s}") + if self.visit_fragment: return IDLE else: @@ -200,23 +338,23 @@ def leave_fragment_spread( fragment_name = node.name.value - for definition in self.document.definitions: - if isinstance(definition, FragmentDefinitionNode): - if definition.name.value == fragment_name: - fragment_result = visit( - definition, - ParseResultVisitor( - self.schema, - self.document, - self.type_info, - self.current_result, - visit_fragment=True, - ), - visitor_keys=RESULT_DOCUMENT_KEYS, - ) - return fragment_result + log.debug(f"Start recursive fragment visit {fragment_name}") - raise GraphQLError(f'Fragment "{fragment_name}" not found in schema!') + fragment_node = _get_fragment(self.document, fragment_name) + + fragment_result = parse_result_recursive( + self.schema, + self.document, + fragment_node, + self.current_result, + visit_fragment=True, + ) + + log.debug( + f"Result of recursive fragment visit {fragment_name}: {fragment_result}" + ) + + return cast(Dict[str, Any], fragment_result) @staticmethod def leave_inline_fragment(node: InlineFragmentNode, *_args: Any) -> Dict[str, Any]: @@ -225,8 +363,43 @@ def leave_inline_fragment(node: InlineFragmentNode, *_args: Any) -> Dict[str, An return {k: v for s in selections for k, v in s.items()} +def parse_result_recursive( + schema: GraphQLSchema, + document: DocumentNode, + node: Node, + result: Optional[Dict[str, Any]], + initial_type: Optional[GraphQLType] = None, + inside_list_level: int = 0, + visit_fragment: bool = False, +) -> Any: + + if result is None: + return None + + type_info = TypeInfo(schema, initial_type=initial_type) + + visited = visit( + node, + TypeInfoVisitor( + type_info, + ParseResultVisitor( + schema, + document, + node, + result, + type_info=type_info, + inside_list_level=inside_list_level, + visit_fragment=visit_fragment, + ), + ), + visitor_keys=RESULT_DOCUMENT_KEYS, + ) + + return visited + + def parse_result( - schema: GraphQLSchema, document: DocumentNode, result: Optional[Dict[str, Any]] + schema: GraphQLSchema, document: DocumentNode, result: Optional[Dict[str, Any]], ) -> Optional[Dict[str, Any]]: """Unserialize a result received from a GraphQL backend. @@ -240,16 +413,4 @@ def parse_result( will be parsed with the parse_value method of the custom scalar or enum definition in the schema.""" - if result is None: - return None - - type_info = TypeInfo(schema) - - visited = visit( - document, - TypeInfoVisitor( - type_info, ParseResultVisitor(schema, document, type_info, result), - ), - ) - - return visited + return parse_result_recursive(schema, document, document, result) diff --git a/tests/conftest.py b/tests/conftest.py index 004fa9df..519738cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -105,6 +105,7 @@ async def go(app, *, port=None, **kwargs): # type: ignore "gql.transport.websockets", "gql.transport.phoenix_channel_websockets", "gql.dsl", + "gql.utilities.parse_result", ]: logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) diff --git a/tests/custom_scalars/test_custom_scalar_money.py b/tests/custom_scalars/test_custom_scalar_money.py index 014938db..7de3e8ac 100644 --- a/tests/custom_scalars/test_custom_scalar_money.py +++ b/tests/custom_scalars/test_custom_scalar_money.py @@ -11,6 +11,7 @@ GraphQLField, GraphQLFloat, GraphQLInt, + GraphQLList, GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, @@ -82,9 +83,34 @@ def parse_money_literal( parse_literal=parse_money_literal, ) +root_value = { + "balance": Money(42, "DM"), + "friends_balance": [Money(12, "EUR"), Money(24, "EUR"), Money(150, "DM")], + "countries_balance": { + "Belgium": Money(15000, "EUR"), + "Luxembourg": Money(99999, "EUR"), + }, +} + def resolve_balance(root, _info): - return root + return root["balance"] + + +def resolve_friends_balance(root, _info): + return root["friends_balance"] + + +def resolve_countries_balance(root, _info): + return root["countries_balance"] + + +def resolve_belgium_balance(countries_balance, _info): + return countries_balance["Belgium"] + + +def resolve_luxembourg_balance(countries_balance, _info): + return countries_balance["Luxembourg"] def resolve_to_euros(_root, _info, money): @@ -97,6 +123,18 @@ def resolve_to_euros(_root, _info, money): raise ValueError("Cannot convert to euros: " + inspect(money)) +countriesBalance = GraphQLObjectType( + name="CountriesBalance", + fields={ + "Belgium": GraphQLField( + GraphQLNonNull(MoneyScalar), resolve=resolve_belgium_balance + ), + "Luxembourg": GraphQLField( + GraphQLNonNull(MoneyScalar), resolve=resolve_luxembourg_balance + ), + }, +) + queryType = GraphQLObjectType( name="RootQueryType", fields={ @@ -106,6 +144,12 @@ def resolve_to_euros(_root, _info, money): args={"money": GraphQLArgument(MoneyScalar)}, resolve=resolve_to_euros, ), + "friends_balance": GraphQLField( + GraphQLList(MoneyScalar), resolve=resolve_friends_balance + ), + "countries_balance": GraphQLField( + GraphQLNonNull(countriesBalance), resolve=resolve_countries_balance, + ), }, ) @@ -133,8 +177,6 @@ async def subscribe_spend_all(_root, _info, money): }, ) -root_value = Money(42, "DM") - schema = GraphQLSchema(query=queryType, subscription=subscriptionType,) @@ -148,7 +190,53 @@ def test_custom_scalar_in_output(): print(result) - assert result["balance"] == root_value + assert result["balance"] == root_value["balance"] + + +def test_custom_scalar_in_output_embedded_fragments(): + + client = Client(schema=schema, parse_results=True) + + query = gql( + """ + fragment LuxMoneyInternal on CountriesBalance { + ... on CountriesBalance { + Luxembourg + } + } + query { + countries_balance { + Belgium + ...LuxMoney + } + } + fragment LuxMoney on CountriesBalance { + ...LuxMoneyInternal + } + """ + ) + + result = client.execute(query, root_value=root_value) + + print(result) + + belgium_money = result["countries_balance"]["Belgium"] + assert belgium_money == Money(15000, "EUR") + luxembourg_money = result["countries_balance"]["Luxembourg"] + assert luxembourg_money == Money(99999, "EUR") + + +def test_custom_scalar_list_in_output(): + + client = Client(schema=schema, parse_results=True) + + query = gql("{friends_balance}") + + result = client.execute(query, root_value=root_value) + + print(result) + + assert result["friends_balance"] == root_value["friends_balance"] def test_custom_scalar_in_input_query(): @@ -387,7 +475,7 @@ async def test_custom_scalar_in_output_with_transport(event_loop, aiohttp_server print(result) - assert result["balance"] == serialize_money(root_value) + assert result["balance"] == serialize_money(root_value["balance"]) @pytest.mark.asyncio diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py new file mode 100644 index 00000000..0b235cb2 --- /dev/null +++ b/tests/custom_scalars/test_enum_colors.py @@ -0,0 +1,202 @@ +from enum import Enum + +from graphql import ( + GraphQLArgument, + GraphQLEnumType, + GraphQLField, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, +) + +from gql import Client, gql + + +class Color(Enum): + RED = 0 + GREEN = 1 + BLUE = 2 + YELLOW = 3 + CYAN = 4 + MAGENTA = 5 + + +RED = Color.RED +GREEN = Color.GREEN +BLUE = Color.BLUE +YELLOW = Color.YELLOW +CYAN = Color.CYAN +MAGENTA = Color.MAGENTA + +ALL_COLORS = [c for c in Color] + +ColorType = GraphQLEnumType("Color", {c.name: c for c in Color}) + + +def resolve_opposite(_root, _info, color): + opposite_colors = { + RED: CYAN, + GREEN: MAGENTA, + BLUE: YELLOW, + YELLOW: BLUE, + CYAN: RED, + MAGENTA: GREEN, + } + + return opposite_colors[color] + + +def resolve_all(_root, _info): + return ALL_COLORS + + +list_of_list_of_list = [[[RED, GREEN], [GREEN, BLUE]], [[YELLOW, CYAN], [MAGENTA, RED]]] + + +def resolve_list_of_list_of_list(_root, _info): + return list_of_list_of_list + + +def resolve_list_of_list(_root, _info): + return list_of_list_of_list[0] + + +def resolve_list(_root, _info): + return list_of_list_of_list[0][0] + + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "all": GraphQLField(GraphQLList(ColorType), resolve=resolve_all,), + "opposite": GraphQLField( + ColorType, + args={"color": GraphQLArgument(ColorType)}, + resolve=resolve_opposite, + ), + "list_of_list_of_list": GraphQLField( + GraphQLNonNull( + GraphQLList( + GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLList(ColorType)))) + ) + ), + resolve=resolve_list_of_list_of_list, + ), + "list_of_list": GraphQLField( + GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLList(ColorType)))), + resolve=resolve_list_of_list, + ), + "list": GraphQLField( + GraphQLNonNull(GraphQLList(ColorType)), resolve=resolve_list, + ), + }, +) + +schema = GraphQLSchema(query=queryType) + + +def test_parse_value_enum(): + + result = ColorType.parse_value("RED") + + print(result) + + assert isinstance(result, Color) + assert result is RED + + +def test_serialize_enum(): + + result = ColorType.serialize(RED) + + print(result) + + assert result == "RED" + + +def test_get_all_colors(): + + query = gql("{all}") + + client = Client(schema=schema, parse_results=True) + + result = client.execute(query) + + print(result) + + all_colors = result["all"] + + assert all_colors == ALL_COLORS + + +def test_opposite_color(): + + query = gql( + """ + query GetOppositeColor($color: Color) { + opposite(color:$color) + }""" + ) + + client = Client(schema=schema, parse_results=True) + + variable_values = { + "color": RED, + } + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + opposite_color = result["opposite"] + + assert isinstance(opposite_color, Color) + assert opposite_color == CYAN + + +def test_list(): + + query = gql("{list}") + + client = Client(schema=schema, parse_results=True) + + result = client.execute(query) + + print(result) + + big_list = result["list"] + + assert big_list == list_of_list_of_list[0][0] + + +def test_list_of_list(): + + query = gql("{list_of_list}") + + client = Client(schema=schema, parse_results=True) + + result = client.execute(query) + + print(result) + + big_list = result["list_of_list"] + + assert big_list == list_of_list_of_list[0] + + +def test_list_of_list_of_list(): + + query = gql("{list_of_list_of_list}") + + client = Client(schema=schema, parse_results=True) + + result = client.execute(query) + + print(result) + + big_list = result["list_of_list_of_list"] + + assert big_list == list_of_list_of_list diff --git a/tests/starwars/test_parse_results.py b/tests/starwars/test_parse_results.py index 72b25177..23073839 100644 --- a/tests/starwars/test_parse_results.py +++ b/tests/starwars/test_parse_results.py @@ -128,7 +128,7 @@ def test_fragment_not_found(): parse_result(StarWarsSchema, query, result) - assert 'Fragment "HumanFragment" not found in schema!' in str(exc_info) + assert 'Fragment "HumanFragment" not found in document!' in str(exc_info) def test_return_none_if_result_is_none(): From 77b70234737609dc718047af19e02b08e5d0d42f Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 00:20:52 +0100 Subject: [PATCH 10/21] Fix starting recursive descent deep in query + fix custom scalars serializing as lists --- gql/utilities/parse_result.py | 27 ++++++++++--------- .../custom_scalars/test_custom_scalar_json.py | 2 +- tests/starwars/test_query.py | 2 +- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index 843632f0..112d8858 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -8,13 +8,11 @@ FieldNode, FragmentDefinitionNode, FragmentSpreadNode, - GraphQLEnumType, GraphQLError, GraphQLInterfaceType, GraphQLList, GraphQLNonNull, GraphQLObjectType, - GraphQLScalarType, GraphQLSchema, GraphQLType, InlineFragmentNode, @@ -24,6 +22,7 @@ TypeInfo, TypeInfoVisitor, Visitor, + is_leaf_type, print_ast, visit, ) @@ -178,6 +177,12 @@ def enter_field( log.debug(f" result_value={result_value}") + # We get the field_type from type_info + field_type = self.type_info.get_type() + + # We calculate a virtual "result type" depending on our recursion level. + result_type = self.get_current_result_type(path) + # If the result for this field is a list, then we need # to recursively visit the same node multiple times for each # item in the list. @@ -185,18 +190,15 @@ def enter_field( not isinstance(result_value, Mapping) and isinstance(result_value, Iterable) and not isinstance(result_value, str) + and not is_leaf_type(result_type) ): - # We get the field_type from type_info - field_type = self.type_info.get_type() - - # We calculate a virtual "result type" depending on our recursion level. - result_type = self.get_current_result_type(path) - + """ if not isinstance(result_type, GraphQLList): raise TypeError( f"Received iterable result for a non-list type: {result_value}" ) + """ # Finding out the inner type of the list inner_type = _ignore_non_null(result_type.of_type) @@ -229,9 +231,10 @@ def enter_field( log.debug(f" recursive path={path!r}") log.debug(f" recursive initial_type={initial_type!r}\n") - # inside_list_level = (self.inside_list_level + 1) - # if self.in_first_field(path) else 1 - inside_list_level = self.inside_list_level + 1 + if self.in_first_field(path): + inside_list_level = self.inside_list_level + 1 + else: + inside_list_level = 1 inner_visit = parse_result_recursive( self.schema, @@ -288,7 +291,7 @@ def leave_field( log.debug(f" field type of {name} is {inspect(field_type)}") log.debug(f" result type of {name} is {inspect(result_type)}") - assert isinstance(result_type, (GraphQLScalarType, GraphQLEnumType)) + assert is_leaf_type(result_type) # Finally parsing a single scalar using the parse_value method parsed_value = result_type.parse_value(self.current_result) diff --git a/tests/custom_scalars/test_custom_scalar_json.py b/tests/custom_scalars/test_custom_scalar_json.py index 80f99850..9659d0a5 100644 --- a/tests/custom_scalars/test_custom_scalar_json.py +++ b/tests/custom_scalars/test_custom_scalar_json.py @@ -94,7 +94,7 @@ def resolve_add_player(root, _info, player): def test_json_value_output(): - client = Client(schema=schema) + client = Client(schema=schema, parse_results=True) query = gql("query {players}") diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index 62890222..520018c1 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -107,7 +107,7 @@ def test_nested_query(client): ], } } - result = client.execute(query) + result = client.execute(query, parse_result=False) assert result == expected From e1deb74d8adffddaad29262c044518e88e5000bc Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 14:23:55 +0100 Subject: [PATCH 11/21] Update docs + add some enum tests --- docs/usage/custom_scalars.rst | 172 ------------ docs/usage/custom_scalars_and_enums.rst | 328 +++++++++++++++++++++++ docs/usage/index.rst | 2 +- tests/custom_scalars/test_enum_colors.py | 43 ++- 4 files changed, 371 insertions(+), 174 deletions(-) delete mode 100644 docs/usage/custom_scalars.rst create mode 100644 docs/usage/custom_scalars_and_enums.rst diff --git a/docs/usage/custom_scalars.rst b/docs/usage/custom_scalars.rst deleted file mode 100644 index 98e3236a..00000000 --- a/docs/usage/custom_scalars.rst +++ /dev/null @@ -1,172 +0,0 @@ -Custom Scalars -============== - -Scalar types represent primitive values at the leaves of a query. - -GraphQL provides a number of built-in scalars (Int, Float, String, Boolean and ID), but a GraphQL backend -can add additional custom scalars to its schema to better express values in their data model. - -For example, a schema can define the Datetime scalar to represent an ISO-8601 encoded date. - -The schema will then only contain: - -.. code-block:: python - - scalar Datetime - -When custom scalars are sent to the backend (as inputs) or from the backend (as outputs), -their values need to be serialized to be composed -of only built-in scalars, then at the destination the serialized values will be parsed again to -be able to represent the scalar in its local internal representation. - -Because this serialization/unserialization is dependent on the language used at both sides, it is not -described in the schema and needs to be defined independently at both sides (client, backend). - -A custom scalar value can have two different representations during its transport: - - - as a serialized value (usually as json): - - * in the results sent by the backend - * in the variables sent by the client alongside the query - - - as "literal" inside the query itself sent by the client - -To define a custom scalar, you need 3 methods: - - - a :code:`serialize` method used: - - * by the backend to serialize a custom scalar output in the result - * by the client to serialize a custom scalar input in the variables - - - a :code:`parse_value` method used: - - * by the backend to unserialize custom scalars inputs in the variables sent by the client - * by the client to unserialize custom scalars outputs from the results - - - a :code:`parse_literal` method used: - - * by the backend to unserialize custom scalars inputs inside the query itself - -To define a custom scalar object, we define a :code:`GraphQLScalarType` from graphql-core with -its name and the implementation of the above methods. - -Example for Datetime: - -.. code-block:: python - - from datetime import datetime - from typing import Any, Dict, Optional - - from graphql import GraphQLScalarType, ValueNode - from graphql.utilities import value_from_ast_untyped - - - def serialize_datetime(value: Any) -> str: - return value.isoformat() - - - def parse_datetime_value(value: Any) -> datetime: - return datetime.fromisoformat(value) - - - def parse_datetime_literal( - value_node: ValueNode, variables: Optional[Dict[str, Any]] = None - ) -> datetime: - ast_value = value_from_ast_untyped(value_node, variables) - return parse_datetime_value(ast_value) - - - DatetimeScalar = GraphQLScalarType( - name="Datetime", - serialize=serialize_datetime, - parse_value=parse_datetime_value, - parse_literal=parse_datetime_literal, - ) - -Custom Scalars in inputs ------------------------- - -To provide custom scalars in input with gql, you can: - -- serialize the scalar yourself as "literal" in the query: - -.. code-block:: python - - query = gql( - """{ - shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) - }""" - ) - -- serialize the scalar yourself in a variable: - -.. code-block:: python - - query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - - variable_values = { - "time": "2021-11-12T11:58:13.461161", - } - - result = client.execute(query, variable_values=variable_values) - -- add a custom scalar to the schema with :func:`update_schema_scalars ` - and execute the query with :code:`serialize_variables=True` - and gql will serialize the variable values from a Python object representation. - -For this, you need to provide a schema or set :code:`fetch_schema_from_transport=True` -in the client to request the schema from the backend. - -.. code-block:: python - - from gql.utilities import update_schema_scalars - - async with Client(transport=transport, fetch_schema_from_transport=True) as session: - - update_schema_scalars(session.client.schema, [DatetimeScalar]) - - query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - - variable_values = {"time": datetime.now()} - - result = await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) - - # result["time"] is a string - -Custom Scalars in output ------------------------- - -By default, gql returns the serialized result from the backend without parsing -(except json unserialization to Python default types). - -if you want to convert the result of custom scalars to custom objects, -you can request gql to parse the results. - -- use :code:`Client(..., parse_results=True)` to request parsing for all queries -- use :code:`execute(..., parse_result=True)` or :code:`subscribe(..., parse_result=True)` if - you want gql to parse only the result of a single query. - -Same example as above, with result parsing enabled: - -.. code-block:: python - - from gql.utilities import update_schema_scalars - - async with Client(transport=transport, fetch_schema_from_transport=True) as session: - - update_schema_scalars(session.client.schema, [DatetimeScalar]) - - query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - - variable_values = {"time": datetime.now()} - - result = await session.execute( - query, - variable_values=variable_values, - serialize_variables=True, - parse_result=True, - ) - - # now result["time"] type is a datetime instead of string diff --git a/docs/usage/custom_scalars_and_enums.rst b/docs/usage/custom_scalars_and_enums.rst new file mode 100644 index 00000000..7758dc7a --- /dev/null +++ b/docs/usage/custom_scalars_and_enums.rst @@ -0,0 +1,328 @@ +Custom scalars and enums +======================== + +.. _custom_scalars: + +Custom scalars +-------------- + +Scalar types represent primitive values at the leaves of a query. + +GraphQL provides a number of built-in scalars (Int, Float, String, Boolean and ID), but a GraphQL backend +can add additional custom scalars to its schema to better express values in their data model. + +For example, a schema can define the Datetime scalar to represent an ISO-8601 encoded date. + +The schema will then only contain:: + + scalar Datetime + +When custom scalars are sent to the backend (as inputs) or from the backend (as outputs), +their values need to be serialized to be composed +of only built-in scalars, then at the destination the serialized values will be parsed again to +be able to represent the scalar in its local internal representation. + +Because this serialization/unserialization is dependent on the language used at both sides, it is not +described in the schema and needs to be defined independently at both sides (client, backend). + +A custom scalar value can have two different representations during its transport: + + - as a serialized value (usually as json): + + * in the results sent by the backend + * in the variables sent by the client alongside the query + + - as "literal" inside the query itself sent by the client + +To define a custom scalar, you need 3 methods: + + - a :code:`serialize` method used: + + * by the backend to serialize a custom scalar output in the result + * by the client to serialize a custom scalar input in the variables + + - a :code:`parse_value` method used: + + * by the backend to unserialize custom scalars inputs in the variables sent by the client + * by the client to unserialize custom scalars outputs from the results + + - a :code:`parse_literal` method used: + + * by the backend to unserialize custom scalars inputs inside the query itself + +To define a custom scalar object, graphql-core provides the :code:`GraphQLScalarType` class +which contains the implementation of the above methods. + +Example for Datetime: + +.. code-block:: python + + from datetime import datetime + from typing import Any, Dict, Optional + + from graphql import GraphQLScalarType, ValueNode + from graphql.utilities import value_from_ast_untyped + + + def serialize_datetime(value: Any) -> str: + return value.isoformat() + + + def parse_datetime_value(value: Any) -> datetime: + return datetime.fromisoformat(value) + + + def parse_datetime_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None + ) -> datetime: + ast_value = value_from_ast_untyped(value_node, variables) + return parse_datetime_value(ast_value) + + + DatetimeScalar = GraphQLScalarType( + name="Datetime", + serialize=serialize_datetime, + parse_value=parse_datetime_value, + parse_literal=parse_datetime_literal, + ) + +If you get your schema from a "schema.graphql" file or from introspection, +then the generated schema in the gql Client will contain default :code:`GraphQLScalarType` instances +where the serialize and parse_value methods simply return the serialized value without modification. + +In that case, if you want gql to parse custom scalars to a more useful Python representation, +or to serialize custom scalars variables from a Python representation, +then you can use the :func:`update_schema_scalars ` method +to modify the definition of the scalars in your schema so that gql could do the parsing/serialization. + +.. code-block:: python + + from gql.utilities import update_schema_scalars + + with open('path/to/schema.graphql') as f: + schema_str = f.read() + + client = Client(schema=schema_str, ...) + + update_schema_scalars(client.schema, [DatetimeScalar]) + +.. _enums: + +Enums +----- + +GraphQL Enum types are a special kind of scalar that is restricted to a particular set of allowed values. + +For example, the schema may have a Color enum and contain:: + + enum Color { + RED + GREEN + BLUE + } + +Graphql-core provides the :code:`GraphQLEnumType` class to define an enum in the schema +(See `graphql-core schema building docs`_). + +This class defines how the enum is serialized and parsed. + +If you get your schema from a "schema.graphql" file or from introspection, +then the generated schema in the gql Client will contain default :code:`GraphQLEnumType` instances +which should serialize/parse enums to/from its String representation (the :code:`RED` enum +will be serialized to :code:`'RED'`). + +You may want to parse enums to convert them to Python Enum types. +In that case, you can use the :func:`update_schema_enum ` +to modify the default :code:`GraphQLEnumType` to use your defined Enum. + +Example: + +.. code-block:: python + + from enum import Enum + from gql.utilities import update_schema_enum + + class Color(Enum): + RED = 0 + GREEN = 1 + BLUE = 2 + + with open('path/to/schema.graphql') as f: + schema_str = f.read() + + client = Client(schema=schema_str, ...) + + update_schema_enum(client.schema, 'Color', Color) + +Serializing Inputs +------------------ + +To provide custom scalars and/or enums in inputs with gql, you can: + +- serialize the inputs manually +- let gql serialize the inputs using the custom scalars and enums defined in the schema + +Manually +^^^^^^^^ + +You can serialize inputs yourself: + + - in the query itself + - in variables + +This has the advantage that you don't need a schema... + +In the query +"""""""""""" + +- custom scalar: + +.. code-block:: python + + query = gql( + """{ + shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) + }""" + ) + +- enum: + +.. code-block:: python + + query = gql("{opposite(color: RED)}") + +In a variable +""""""""""""" + +- custom scalar: + +.. code-block:: python + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": "2021-11-12T11:58:13.461161", + } + + result = client.execute(query, variable_values=variable_values) + +- enum: + +.. code-block:: python + + query = gql( + """ + query GetOppositeColor($color: Color) { + opposite(color:$color) + }""" + ) + + variable_values = { + "color": 'RED', + } + + result = client.execute(query, variable_values=variable_values) + +Automatically +^^^^^^^^^^^^^ + +If you have custom scalar and/or enums defined in your schema +(See: :ref:`custom_scalars` and :ref:`enums`), +then you can request gql to serialize your variables automatically. + +- use :code:`Client(..., serialize_variables=True)` to request serializing variables for all queries +- use :code:`execute(..., serialize_variables=True)` or :code:`subscribe(..., serialize_variables=True)` if + you want gql to serialize the variables only for a single query. + +Examples: + +- custom scalars: + +.. code-block:: python + + from gql.utilities import update_schema_scalars + + from .myscalars import DatetimeScalar + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + # We update the schema we got from introspection with our custom scalar type + update_schema_scalars(session.client.schema, [DatetimeScalar]) + + # In the query, the custom scalar in the input is set to a variable + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + # the argument for time is a datetime instance + variable_values = {"time": datetime.now()} + + # we execute the query with serialize_variables set to True + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + +- enums: + +.. code-block:: python + + from gql.utilities import update_schema_enum + + from .myenums import Color + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + # We update the schema we got from introspection with our custom enum + update_schema_enum(session.client.schema, 'Color', Color) + + # In the query, the enum in the input is set to a variable + query = gql( + """ + query GetOppositeColor($color: Color) { + opposite(color:$color) + }""" + ) + + # the argument for time is an instance of our Enum type + variable_values = { + "color": Color.RED, + } + + # we execute the query with serialize_variables set to True + result = client.execute(query, variable_values=variable_values) + +Parsing output +-------------- + +By default, gql returns the serialized result from the backend without parsing +(except json unserialization to Python default types). + +if you want to convert the result of custom scalars to custom objects, +you can request gql to parse the results. + +- use :code:`Client(..., parse_results=True)` to request parsing for all queries +- use :code:`execute(..., parse_result=True)` or :code:`subscribe(..., parse_result=True)` if + you want gql to parse only the result of a single query. + +Same example as above, with result parsing enabled: + +.. code-block:: python + + from gql.utilities import update_schema_scalars + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + update_schema_scalars(session.client.schema, [DatetimeScalar]) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = {"time": datetime.now()} + + result = await session.execute( + query, + variable_values=variable_values, + serialize_variables=True, + parse_result=True, + ) + + # now result["time"] type is a datetime instead of string + +.. _graphql-core schema building docs: https://graphql-core-3.readthedocs.io/en/latest/usage/schema.html diff --git a/docs/usage/index.rst b/docs/usage/index.rst index 4a38093a..eebf9fd2 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -10,4 +10,4 @@ Usage variables headers file_upload - custom_scalars + custom_scalars_and_enums diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py index 0b235cb2..76876a3c 100644 --- a/tests/custom_scalars/test_enum_colors.py +++ b/tests/custom_scalars/test_enum_colors.py @@ -130,7 +130,25 @@ def test_get_all_colors(): assert all_colors == ALL_COLORS -def test_opposite_color(): +def test_opposite_color_literal(): + + client = Client(schema=schema, parse_results=True) + + query = gql("{opposite(color: RED)}") + + result = client.execute(query) + + print(result) + + opposite_color = result["opposite"] + + assert isinstance(opposite_color, Color) + assert opposite_color == CYAN + + +def test_opposite_color_variable_serialized_manually(): + + client = Client(schema=schema, parse_results=True) query = gql( """ @@ -139,8 +157,31 @@ def test_opposite_color(): }""" ) + variable_values = { + "color": "RED", + } + + result = client.execute(query, variable_values=variable_values) + + print(result) + + opposite_color = result["opposite"] + + assert isinstance(opposite_color, Color) + assert opposite_color == CYAN + + +def test_opposite_color_variable_serialized_by_gql(): + client = Client(schema=schema, parse_results=True) + query = gql( + """ + query GetOppositeColor($color: Color) { + opposite(color:$color) + }""" + ) + variable_values = { "color": RED, } From fd96eaf37a2b688aa79eb3a80bb44a49bd0efe51 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 14:44:32 +0100 Subject: [PATCH 12/21] rename custom scalar test files --- .../{test_custom_scalar_datetime.py => test_datetime.py} | 0 tests/custom_scalars/{test_custom_scalar_json.py => test_json.py} | 0 .../custom_scalars/{test_custom_scalar_money.py => test_money.py} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/custom_scalars/{test_custom_scalar_datetime.py => test_datetime.py} (100%) rename tests/custom_scalars/{test_custom_scalar_json.py => test_json.py} (100%) rename tests/custom_scalars/{test_custom_scalar_money.py => test_money.py} (100%) diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_datetime.py similarity index 100% rename from tests/custom_scalars/test_custom_scalar_datetime.py rename to tests/custom_scalars/test_datetime.py diff --git a/tests/custom_scalars/test_custom_scalar_json.py b/tests/custom_scalars/test_json.py similarity index 100% rename from tests/custom_scalars/test_custom_scalar_json.py rename to tests/custom_scalars/test_json.py diff --git a/tests/custom_scalars/test_custom_scalar_money.py b/tests/custom_scalars/test_money.py similarity index 100% rename from tests/custom_scalars/test_custom_scalar_money.py rename to tests/custom_scalars/test_money.py From c564d9e1bd48295932db55d8014de300b06f4a97 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 16:50:43 +0100 Subject: [PATCH 13/21] Add update_schema_enum function --- gql/utilities/__init__.py | 2 + gql/utilities/update_schema_enum.py | 69 ++++++++++++++++++++++++ tests/custom_scalars/test_enum_colors.py | 42 +++++++++++++++ 3 files changed, 113 insertions(+) create mode 100644 gql/utilities/update_schema_enum.py diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py index b7ab80e7..bc7552df 100644 --- a/gql/utilities/__init__.py +++ b/gql/utilities/__init__.py @@ -1,7 +1,9 @@ from .parse_result import parse_result +from .update_schema_enum import update_schema_enum from .update_schema_scalars import update_schema_scalars __all__ = [ "update_schema_scalars", + "update_schema_enum", "parse_result", ] diff --git a/gql/utilities/update_schema_enum.py b/gql/utilities/update_schema_enum.py new file mode 100644 index 00000000..b4ba71ba --- /dev/null +++ b/gql/utilities/update_schema_enum.py @@ -0,0 +1,69 @@ +from enum import Enum +from typing import Any, Dict, Mapping, Type, Union, cast + +from graphql import GraphQLEnumType, GraphQLSchema + + +def update_schema_enum( + schema: GraphQLSchema, + name: str, + values: Union[Dict[str, Any], Type[Enum]], + use_enum_values: bool = False, +): + """Update in the schema the GraphQLEnumType corresponding to the given name. + + Example:: + + from enum import Enum + + class Color(Enum): + RED = 0 + GREEN = 1 + BLUE = 2 + + update_schema_enum(schema, 'Color', Color) + + :param schema: a GraphQL Schema already containing the GraphQLEnumType type. + :param name: the name of the enum in the GraphQL schema + :values: Either a Python Enum or a dict of values. The keys of the provided + values should correspond to the keys of the existing enum in the schema. + :use_enum_values: By default, we configure the GraphQLEnumType to serialize + to enum instances (ie: .parse_value() returns Color.RED). + If use_enum_values is set to True, then .parse_value() returns 0. + use_enum_values=True is the defaut behaviour when passing an Enum + to a GraphQLEnumType. + """ + + # Convert Enum values to Dict + if isinstance(values, type): + if issubclass(values, Enum): + values = cast(Type[Enum], values) + if use_enum_values: + values = {enum.name: enum.value for enum in values} + else: + values = {enum.name: enum for enum in values} + + if not isinstance(values, Mapping): + raise TypeError(f"Invalid type for enum values: {type(values)}") + + # Find enum type in schema + schema_enum = schema.get_type(name) + + if schema_enum is None: + raise KeyError(f"Enum {name} not found in schema!") + + if not isinstance(schema_enum, GraphQLEnumType): + raise TypeError( + f'The type "{name}" is not a GraphQLEnumType, it is a {type(schema_enum)}' + ) + + # Replace all enum values + for enum_name, enum_value in schema_enum.values.items(): + try: + enum_value.value = values[enum_name] + except KeyError: + raise KeyError(f'Enum key "{enum_name}" not found in provided values!') + + # Delete the _value_lookup cached property + if "_value_lookup" in schema_enum.__dict__: + del schema_enum.__dict__["_value_lookup"] diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py index 76876a3c..f76a581c 100644 --- a/tests/custom_scalars/test_enum_colors.py +++ b/tests/custom_scalars/test_enum_colors.py @@ -1,5 +1,6 @@ from enum import Enum +import pytest from graphql import ( GraphQLArgument, GraphQLEnumType, @@ -11,6 +12,7 @@ ) from gql import Client, gql +from gql.utilities import update_schema_enum class Color(Enum): @@ -241,3 +243,43 @@ def test_list_of_list_of_list(): big_list = result["list_of_list_of_list"] assert big_list == list_of_list_of_list + + +def test_update_schema_enum(): + + assert schema.get_type("Color").parse_value("RED") == Color.RED + + # Using values + + update_schema_enum(schema, "Color", Color, use_enum_values=True) + + assert schema.get_type("Color").parse_value("RED") == 0 + assert schema.get_type("Color").serialize(1) == "GREEN" + + update_schema_enum(schema, "Color", Color) + + assert schema.get_type("Color").parse_value("RED") == Color.RED + assert schema.get_type("Color").serialize(Color.RED) == "RED" + + +def test_update_schema_enum_errors(): + + with pytest.raises(KeyError) as exc_info: + update_schema_enum(schema, "Corlo", Color) + + assert "Enum Corlo not found in schema!" in str(exc_info) + + with pytest.raises(TypeError) as exc_info: + update_schema_enum(schema, "Color", 6) + + assert "Invalid type for enum values: " in str(exc_info) + + with pytest.raises(TypeError) as exc_info: + update_schema_enum(schema, "RootQueryType", Color) + + assert 'The type "RootQueryType" is not a GraphQLEnumType, it is a' in str(exc_info) + + with pytest.raises(KeyError) as exc_info: + update_schema_enum(schema, "Color", {"RED": Color.RED}) + + assert 'Enum key "GREEN" not found in provided values!' in str(exc_info) From 7d83add2aec39c9c4dc8dd7496b7c5497d8247c1 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 17:27:58 +0100 Subject: [PATCH 14/21] Allow to define serialize_variables argument in Client --- gql/client.py | 67 ++++++++++++++++----------- tests/custom_scalars/test_datetime.py | 6 +-- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/gql/client.py b/gql/client.py index 32e62b6d..c6922ca7 100644 --- a/gql/client.py +++ b/gql/client.py @@ -49,6 +49,7 @@ def __init__( transport: Optional[Union[Transport, AsyncTransport]] = None, fetch_schema_from_transport: bool = False, execute_timeout: Optional[Union[int, float]] = 10, + serialize_variables: bool = False, parse_results: bool = False, ): """Initialize the client with the given parameters. @@ -61,6 +62,8 @@ def __init__( :param execute_timeout: The maximum time in seconds for the execution of a request before a TimeoutError is raised. Only used for async transports. Passing None results in waiting forever for a response. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. :param parse_results: Whether gql will try to parse the serialized output sent by the backend. Can be used to unserialize custom scalars or enums. """ @@ -112,6 +115,7 @@ def __init__( # Enforced timeout of the execute function (only for async transports) self.execute_timeout = execute_timeout + self.serialize_variables = serialize_variables self.parse_results = parse_results def validate(self, document: DocumentNode): @@ -302,7 +306,7 @@ def _execute( *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs, ) -> ExecutionResult: @@ -324,13 +328,16 @@ def _execute( self.client.validate(document) # Parse variable values for custom scalars if requested - if serialize_variables and variable_values is not None: - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + if variable_values is not None: + if serialize_variables or ( + serialize_variables is None and self.client.serialize_variables + ): + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) result = self.transport.execute( document, @@ -353,7 +360,7 @@ def execute( *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs, ) -> Dict: @@ -428,7 +435,7 @@ async def _subscribe( *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs, ) -> AsyncGenerator[ExecutionResult, None]: @@ -454,13 +461,16 @@ async def _subscribe( self.client.validate(document) # Parse variable values for custom scalars if requested - if serialize_variables and variable_values is not None: - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + if variable_values is not None: + if serialize_variables or ( + serialize_variables is None and self.client.serialize_variables + ): + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) # Subscribe to the transport inner_generator: AsyncGenerator[ @@ -499,7 +509,7 @@ async def subscribe( *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs, ) -> AsyncGenerator[Dict, None]: @@ -550,7 +560,7 @@ async def _execute( *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs, ) -> ExecutionResult: @@ -575,13 +585,16 @@ async def _execute( self.client.validate(document) # Parse variable values for custom scalars if requested - if serialize_variables and variable_values is not None: - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + if variable_values is not None: + if serialize_variables or ( + serialize_variables is None and self.client.serialize_variables + ): + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) # Execute the query with the transport with a timeout result = await asyncio.wait_for( @@ -608,7 +621,7 @@ async def execute( *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs, ) -> Dict: diff --git a/tests/custom_scalars/test_datetime.py b/tests/custom_scalars/test_datetime.py index fe031dd3..169ce076 100644 --- a/tests/custom_scalars/test_datetime.py +++ b/tests/custom_scalars/test_datetime.py @@ -112,7 +112,7 @@ def resolve_seconds(root, _info, interval): ) def test_shift_days(): - client = Client(schema=schema, parse_results=True) + client = Client(schema=schema, parse_results=True, serialize_variables=True) now = datetime.fromisoformat("2021-11-12T11:58:13.461161") @@ -122,9 +122,7 @@ def test_shift_days(): "time": now, } - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, variable_values=variable_values) print(result) From 16f4787fecea9756981dcce463afc6dc2b476d0b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 18:05:13 +0100 Subject: [PATCH 15/21] Add update_schema_scalar to update a single scalar with name + raises TypeError and KeyError instead of GraphQLErrors --- docs/usage/custom_scalars_and_enums.rst | 11 ++++-- gql/utilities/__init__.py | 3 +- gql/utilities/update_schema_enum.py | 4 +- gql/utilities/update_schema_scalars.py | 49 +++++++++++++++++-------- tests/custom_scalars/test_money.py | 32 +++++++++++++--- 5 files changed, 70 insertions(+), 29 deletions(-) diff --git a/docs/usage/custom_scalars_and_enums.rst b/docs/usage/custom_scalars_and_enums.rst index 7758dc7a..659b4502 100644 --- a/docs/usage/custom_scalars_and_enums.rst +++ b/docs/usage/custom_scalars_and_enums.rst @@ -92,19 +92,22 @@ where the serialize and parse_value methods simply return the serialized value w In that case, if you want gql to parse custom scalars to a more useful Python representation, or to serialize custom scalars variables from a Python representation, -then you can use the :func:`update_schema_scalars ` method -to modify the definition of the scalars in your schema so that gql could do the parsing/serialization. +then you can use the :func:`update_schema_scalars ` +or :func:`update_schema_scalar ` methods +to modify the definition of a scalar in your schema so that gql could do the parsing/serialization. .. code-block:: python - from gql.utilities import update_schema_scalars + from gql.utilities import update_schema_scalar with open('path/to/schema.graphql') as f: schema_str = f.read() client = Client(schema=schema_str, ...) - update_schema_scalars(client.schema, [DatetimeScalar]) + update_schema_scalar(client.schema, "Datetime", DatetimeScalar) + + # or update_schema_scalars(client.schema, [DatetimeScalar]) .. _enums: diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py index bc7552df..3ebdf81f 100644 --- a/gql/utilities/__init__.py +++ b/gql/utilities/__init__.py @@ -1,9 +1,10 @@ from .parse_result import parse_result from .update_schema_enum import update_schema_enum -from .update_schema_scalars import update_schema_scalars +from .update_schema_scalars import update_schema_scalar, update_schema_scalars __all__ = [ "update_schema_scalars", + "update_schema_scalar", "update_schema_enum", "parse_result", ] diff --git a/gql/utilities/update_schema_enum.py b/gql/utilities/update_schema_enum.py index b4ba71ba..80c73862 100644 --- a/gql/utilities/update_schema_enum.py +++ b/gql/utilities/update_schema_enum.py @@ -25,9 +25,9 @@ class Color(Enum): :param schema: a GraphQL Schema already containing the GraphQLEnumType type. :param name: the name of the enum in the GraphQL schema - :values: Either a Python Enum or a dict of values. The keys of the provided + :param values: Either a Python Enum or a dict of values. The keys of the provided values should correspond to the keys of the existing enum in the schema. - :use_enum_values: By default, we configure the GraphQLEnumType to serialize + :param use_enum_values: By default, we configure the GraphQLEnumType to serialize to enum instances (ie: .parse_value() returns Color.RED). If use_enum_values is set to True, then .parse_value() returns 0. use_enum_values=True is the defaut behaviour when passing an Enum diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py index d5434c6b..8e34e7ad 100644 --- a/gql/utilities/update_schema_scalars.py +++ b/gql/utilities/update_schema_scalars.py @@ -1,6 +1,35 @@ from typing import Iterable, List -from graphql import GraphQLError, GraphQLScalarType, GraphQLSchema +from graphql import GraphQLScalarType, GraphQLSchema + + +def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalarType): + """Update the scalar in a schema with the scalar provided. + + This can be used to update the default Custom Scalar implementation + when the schema has been provided from a text file or from introspection. + """ + + if not isinstance(scalar, GraphQLScalarType): + raise TypeError("Scalars should be instances of GraphQLScalarType.") + + schema_scalar = schema.get_type(name) + + if schema_scalar is None: + raise KeyError(f"Scalar '{name}' not found in schema.") + + if not isinstance(schema_scalar, GraphQLScalarType): + raise TypeError( + f'The type "{name}" is not a GraphQLScalarType,' + f"it is a {type(schema_scalar)}" + ) + + # Update the conversion methods + # Using setattr because mypy has a false positive + # https://github.com/python/mypy/issues/2427 + setattr(schema_scalar, "serialize", scalar.serialize) + setattr(schema_scalar, "parse_value", scalar.parse_value) + setattr(schema_scalar, "parse_literal", scalar.parse_literal) def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]): @@ -11,22 +40,10 @@ def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType """ if not isinstance(scalars, Iterable): - raise GraphQLError("Scalars argument should be a list of scalars.") + raise TypeError("Scalars argument should be a list of scalars.") for scalar in scalars: if not isinstance(scalar, GraphQLScalarType): - raise GraphQLError("Scalars should be instances of GraphQLScalarType.") - - try: - schema_scalar = schema.type_map[scalar.name] - except KeyError: - raise GraphQLError(f"Scalar '{scalar.name}' not found in schema.") - - assert isinstance(schema_scalar, GraphQLScalarType) + raise TypeError("Scalars should be instances of GraphQLScalarType.") - # Update the conversion methods - # Using setattr because mypy has a false positive - # https://github.com/python/mypy/issues/2427 - setattr(schema_scalar, "serialize", scalar.serialize) - setattr(schema_scalar, "parse_value", scalar.parse_value) - setattr(schema_scalar, "parse_literal", scalar.parse_literal) + update_schema_scalar(schema, scalar.name, scalar) diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 7de3e8ac..01a58e02 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -21,7 +21,7 @@ from gql import Client, gql from gql.transport.exceptions import TransportQueryError -from gql.utilities import update_schema_scalars +from gql.utilities import update_schema_scalar, update_schema_scalars from gql.variable_values import serialize_value from ..conftest import MS @@ -623,7 +623,8 @@ async def test_update_schema_scalars(event_loop, aiohttp_server): # Update the schema MoneyScalar default implementation from # introspection with our provided conversion methods - update_schema_scalars(session.client.schema, [MoneyScalar]) + # update_schema_scalars(session.client.schema, [MoneyScalar]) + update_schema_scalar(session.client.schema, "Money", MoneyScalar) query = gql("query myquery($money: Money) {toEuros(money: $money)}") @@ -639,17 +640,24 @@ async def test_update_schema_scalars(event_loop, aiohttp_server): def test_update_schema_scalars_invalid_scalar(): - with pytest.raises(GraphQLError) as exc_info: + with pytest.raises(TypeError) as exc_info: update_schema_scalars(schema, [int]) exception = exc_info.value assert str(exception) == "Scalars should be instances of GraphQLScalarType." + with pytest.raises(TypeError) as exc_info: + update_schema_scalar(schema, "test", int) + + exception = exc_info.value + + assert str(exception) == "Scalars should be instances of GraphQLScalarType." + def test_update_schema_scalars_invalid_scalar_argument(): - with pytest.raises(GraphQLError) as exc_info: + with pytest.raises(TypeError) as exc_info: update_schema_scalars(schema, MoneyScalar) exception = exc_info.value @@ -661,12 +669,24 @@ def test_update_schema_scalars_scalar_not_found_in_schema(): NotFoundScalar = GraphQLScalarType(name="abcd",) - with pytest.raises(GraphQLError) as exc_info: + with pytest.raises(KeyError) as exc_info: update_schema_scalars(schema, [MoneyScalar, NotFoundScalar]) exception = exc_info.value - assert str(exception) == "Scalar 'abcd' not found in schema." + assert "Scalar 'abcd' not found in schema." in str(exception) + + +def test_update_schema_scalars_scalar_type_is_not_a_scalar_in_schema(): + + with pytest.raises(TypeError) as exc_info: + update_schema_scalar(schema, "CountriesBalance", MoneyScalar) + + exception = exc_info.value + + assert 'The type "CountriesBalance" is not a GraphQLScalarType, it is a' in str( + exception + ) @pytest.mark.asyncio From ee8dd745d016c2a92928f70fa3d6df770893b2e1 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 18:12:48 +0100 Subject: [PATCH 16/21] moving serialize_variable_values into utilities folder --- gql/client.py | 2 +- gql/utilities/__init__.py | 3 +++ .../serialize_variable_values.py} | 4 ++-- gql/utilities/update_schema_scalars.py | 2 +- tests/custom_scalars/test_money.py | 3 +-- 5 files changed, 8 insertions(+), 6 deletions(-) rename gql/{variable_values.py => utilities/serialize_variable_values.py} (97%) diff --git a/gql/client.py b/gql/client.py index c6922ca7..56dd8a0d 100644 --- a/gql/client.py +++ b/gql/client.py @@ -18,7 +18,7 @@ from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport from .utilities import parse_result as parse_result_fn -from .variable_values import serialize_variable_values +from .utilities import serialize_variable_values class Client: diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py index 3ebdf81f..d17f9b2d 100644 --- a/gql/utilities/__init__.py +++ b/gql/utilities/__init__.py @@ -1,4 +1,5 @@ from .parse_result import parse_result +from .serialize_variable_values import serialize_value, serialize_variable_values from .update_schema_enum import update_schema_enum from .update_schema_scalars import update_schema_scalar, update_schema_scalars @@ -7,4 +8,6 @@ "update_schema_scalar", "update_schema_enum", "parse_result", + "serialize_variable_values", + "serialize_value", ] diff --git a/gql/variable_values.py b/gql/utilities/serialize_variable_values.py similarity index 97% rename from gql/variable_values.py rename to gql/utilities/serialize_variable_values.py index 7db7091a..b5b91cf4 100644 --- a/gql/variable_values.py +++ b/gql/utilities/serialize_variable_values.py @@ -17,7 +17,7 @@ from graphql.pyutils import inspect -def get_document_operation( +def _get_document_operation( document: DocumentNode, operation_name: Optional[str] = None ) -> OperationDefinitionNode: """Returns the operation which should be executed in the document. @@ -99,7 +99,7 @@ def serialize_variable_values( parsed_variable_values: Dict[str, Any] = {} # Find the operation in the document - operation = get_document_operation(document, operation_name=operation_name) + operation = _get_document_operation(document, operation_name=operation_name) # Serialize every variable value defined for the operation for var_def_node in operation.variable_definitions: diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py index 8e34e7ad..270fe44f 100644 --- a/gql/utilities/update_schema_scalars.py +++ b/gql/utilities/update_schema_scalars.py @@ -21,7 +21,7 @@ def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalar if not isinstance(schema_scalar, GraphQLScalarType): raise TypeError( f'The type "{name}" is not a GraphQLScalarType,' - f"it is a {type(schema_scalar)}" + f" it is a {type(schema_scalar)}" ) # Update the conversion methods diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 01a58e02..1b65ec98 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -21,8 +21,7 @@ from gql import Client, gql from gql.transport.exceptions import TransportQueryError -from gql.utilities import update_schema_scalar, update_schema_scalars -from gql.variable_values import serialize_value +from gql.utilities import serialize_value, update_schema_scalar, update_schema_scalars from ..conftest import MS From 02dc49e3812da9ec98682b988535deb8e931f047 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 18:33:20 +0100 Subject: [PATCH 17/21] update code documentation for utilities functions --- gql/utilities/parse_result.py | 7 +++++++ gql/utilities/serialize_variable_values.py | 14 +++++++++++++- gql/utilities/update_schema_scalars.py | 11 +++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index 112d8858..992e5bb1 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -406,6 +406,13 @@ def parse_result( ) -> Optional[Dict[str, Any]]: """Unserialize a result received from a GraphQL backend. + :param schema: the GraphQL schema + :param document: the document representing the query sent to the backend + :param result: the serialized result received from the backend + + :returns: a parsed result with scalars and enums parsed depending on + their definition in the schema. + Given a schema, a query and a serialized result, provide a new result with parsed values. diff --git a/gql/utilities/serialize_variable_values.py b/gql/utilities/serialize_variable_values.py index b5b91cf4..833df8bd 100644 --- a/gql/utilities/serialize_variable_values.py +++ b/gql/utilities/serialize_variable_values.py @@ -53,7 +53,13 @@ def _get_document_operation( def serialize_value(type_: GraphQLType, value: Any) -> Any: """Given a GraphQL type and a Python value, return the serialized value. + This method will serialize the value recursively, entering into + lists and dicts. + Can be used to serialize Enums and/or Custom Scalars in variable values. + + :param type_: the GraphQL type + :param value: the provided value """ if value is None: @@ -93,7 +99,13 @@ def serialize_variable_values( """Given a GraphQL document and a schema, serialize the Dictionary of variable values. - Useful to serialize Enums and/or Custom Scalars in variable values + Useful to serialize Enums and/or Custom Scalars in variable values. + + :param schema: the GraphQL schema + :param document: the document representing the query sent to the backend + :param variable_values: the dictionnary of variable values which needs + to be serialized. + :param operation_name: the optional operation_name for the query. """ parsed_variable_values: Dict[str, Any] = {} diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py index 270fe44f..db3adb17 100644 --- a/gql/utilities/update_schema_scalars.py +++ b/gql/utilities/update_schema_scalars.py @@ -6,6 +6,10 @@ def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalarType): """Update the scalar in a schema with the scalar provided. + :param schema: the GraphQL schema + :param name: the name of the custom scalar type in the schema + :param scalar: a provided scalar type + This can be used to update the default Custom Scalar implementation when the schema has been provided from a text file or from introspection. """ @@ -35,8 +39,15 @@ def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalar def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]): """Update the scalars in a schema with the scalars provided. + :param schema: the GraphQL schema + :param scalars: a list of provided scalar types + This can be used to update the default Custom Scalar implementation when the schema has been provided from a text file or from introspection. + + If the name of the provided scalar is different than the name of + the custom scalar, then you should use the + :func:`update_schema_scalar ` method instead. """ if not isinstance(scalars, Iterable): From 6557abe5c5573d224c0cd5608bda2efd35d97a1e Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 18:36:10 +0100 Subject: [PATCH 18/21] Doc add link to README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 8fefeb2f..a85761e1 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ The main features of GQL are: * Supports GraphQL queries, mutations and [subscriptions](https://gql.readthedocs.io/en/latest/usage/subscriptions.html) * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) * Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html) +* Supports [Custom scalars / Enums](https://gql.readthedocs.io/en/latest/usage/custom_scalars_and_enums.html) * [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries from the command line * [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically From 7ea8a8c3848678fb5e7c5d4855a9340a54fc460c Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 18:46:04 +0100 Subject: [PATCH 19/21] docs fix small typo --- docs/usage/custom_scalars_and_enums.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/usage/custom_scalars_and_enums.rst b/docs/usage/custom_scalars_and_enums.rst index 659b4502..fc9008d8 100644 --- a/docs/usage/custom_scalars_and_enums.rst +++ b/docs/usage/custom_scalars_and_enums.rst @@ -290,7 +290,9 @@ Examples: } # we execute the query with serialize_variables set to True - result = client.execute(query, variable_values=variable_values) + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) Parsing output -------------- From 1e6ff7f19b7f3306ff14ff07149492d6f978ffc6 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 20:04:18 +0100 Subject: [PATCH 20/21] add operation_name to parse_results --- gql/client.py | 19 +++++++++-- gql/utilities/parse_result.py | 31 ++++++++++++++++-- tests/custom_scalars/test_enum_colors.py | 40 ++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 5 deletions(-) diff --git a/gql/client.py b/gql/client.py index 56dd8a0d..079bb552 100644 --- a/gql/client.py +++ b/gql/client.py @@ -350,7 +350,12 @@ def _execute( # Unserialize the result if requested if self.client.schema: if parse_result or (parse_result is None and self.client.parse_results): - result.data = parse_result_fn(self.client.schema, document, result.data) + result.data = parse_result_fn( + self.client.schema, + document, + result.data, + operation_name=operation_name, + ) return result @@ -495,7 +500,10 @@ async def _subscribe( parse_result is None and self.client.parse_results ): result.data = parse_result_fn( - self.client.schema, document, result.data + self.client.schema, + document, + result.data, + operation_name=operation_name, ) yield result @@ -611,7 +619,12 @@ async def _execute( # Unserialize the result if requested if self.client.schema: if parse_result or (parse_result is None and self.client.parse_results): - result.data = parse_result_fn(self.client.schema, document, result.data) + result.data = parse_result_fn( + self.client.schema, + document, + result.data, + operation_name=operation_name, + ) return result diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index 992e5bb1..badbfb0b 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -16,6 +16,7 @@ GraphQLSchema, GraphQLType, InlineFragmentNode, + NameNode, Node, OperationDefinitionNode, SelectionSetNode, @@ -71,6 +72,7 @@ def __init__( type_info: TypeInfo, visit_fragment: bool = False, inside_list_level: int = 0, + operation_name: Optional[str] = None, ): """Recursive Implementation of a Visitor class to parse results correspondind to a schema and a document. @@ -96,6 +98,7 @@ def __init__( self.type_info: TypeInfo = type_info self.visit_fragment: bool = visit_fragment self.inside_list_level = inside_list_level + self.operation_name = operation_name self.result_stack: List[Any] = [] @@ -111,6 +114,22 @@ def leave_document(node: DocumentNode, *_args: Any) -> Dict[str, Any]: results = cast(List[Dict[str, Any]], node.definitions) return {k: v for result in results for k, v in result.items()} + def enter_operation_definition( + self, node: OperationDefinitionNode, *_args: Any + ) -> Union[None, VisitorActionEnum]: + + if self.operation_name is not None: + if not hasattr(node.name, "value"): + return REMOVE # pragma: no cover + + node.name = cast(NameNode, node.name) + + if node.name.value != self.operation_name: + log.debug(f"SKIPPING operation {node.name.value}") + return REMOVE + + return IDLE + @staticmethod def leave_operation_definition( node: OperationDefinitionNode, *_args: Any @@ -374,6 +393,7 @@ def parse_result_recursive( initial_type: Optional[GraphQLType] = None, inside_list_level: int = 0, visit_fragment: bool = False, + operation_name: Optional[str] = None, ) -> Any: if result is None: @@ -393,6 +413,7 @@ def parse_result_recursive( type_info=type_info, inside_list_level=inside_list_level, visit_fragment=visit_fragment, + operation_name=operation_name, ), ), visitor_keys=RESULT_DOCUMENT_KEYS, @@ -402,13 +423,17 @@ def parse_result_recursive( def parse_result( - schema: GraphQLSchema, document: DocumentNode, result: Optional[Dict[str, Any]], + schema: GraphQLSchema, + document: DocumentNode, + result: Optional[Dict[str, Any]], + operation_name: Optional[str] = None, ) -> Optional[Dict[str, Any]]: """Unserialize a result received from a GraphQL backend. :param schema: the GraphQL schema :param document: the document representing the query sent to the backend :param result: the serialized result received from the backend + :param operation_name: the optional operation name :returns: a parsed result with scalars and enums parsed depending on their definition in the schema. @@ -423,4 +448,6 @@ def parse_result( will be parsed with the parse_value method of the custom scalar or enum definition in the schema.""" - return parse_result_recursive(schema, document, document, result) + return parse_result_recursive( + schema, document, document, result, operation_name=operation_name + ) diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py index f76a581c..2c7b887c 100644 --- a/tests/custom_scalars/test_enum_colors.py +++ b/tests/custom_scalars/test_enum_colors.py @@ -283,3 +283,43 @@ def test_update_schema_enum_errors(): update_schema_enum(schema, "Color", {"RED": Color.RED}) assert 'Enum key "GREEN" not found in provided values!' in str(exc_info) + + +def test_parse_results_with_operation_type(): + + client = Client(schema=schema, parse_results=True) + + query = gql( + """ + query GetAll { + all + } + query GetOppositeColor($color: Color) { + opposite(color:$color) + } + query GetOppositeColor2($color: Color) { + other_opposite:opposite(color:$color) + } + query GetOppositeColor3 { + opposite(color: YELLOW) + } + query GetListOfListOfList { + list_of_list_of_list + } + """ + ) + + variable_values = { + "color": "RED", + } + + result = client.execute( + query, variable_values=variable_values, operation_name="GetOppositeColor" + ) + + print(result) + + opposite_color = result["opposite"] + + assert isinstance(opposite_color, Color) + assert opposite_color == CYAN From 3442a32bfab18db9ac3a7030afe14c192ffdf0e2 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 20:27:46 +0100 Subject: [PATCH 21/21] Remove obsolete comment --- gql/utilities/parse_result.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index badbfb0b..ecb73474 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -212,13 +212,6 @@ def enter_field( and not is_leaf_type(result_type) ): - """ - if not isinstance(result_type, GraphQLList): - raise TypeError( - f"Received iterable result for a non-list type: {result_value}" - ) - """ - # Finding out the inner type of the list inner_type = _ignore_non_null(result_type.of_type)