From 97cd612bdd89c45e400ffaa1afc1697c344e653e Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 12 Nov 2021 15:16:03 +0100 Subject: [PATCH 1/5] Implementation of serialization of variable values --- gql/client.py | 199 ++++++- gql/variable_values.py | 117 ++++ tests/custom_scalars/__init__.py | 0 .../test_custom_scalar_datetime.py | 174 ++++++ .../test_custom_scalar_money.py | 527 ++++++++++++++++++ 5 files changed, 1005 insertions(+), 12 deletions(-) create mode 100644 gql/variable_values.py create mode 100644 tests/custom_scalars/__init__.py create mode 100644 tests/custom_scalars/test_custom_scalar_datetime.py create mode 100644 tests/custom_scalars/test_custom_scalar_money.py diff --git a/gql/client.py b/gql/client.py index 6017ab69..368193cc 100644 --- a/gql/client.py +++ b/gql/client.py @@ -17,6 +17,7 @@ from .transport.exceptions import TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport +from .variable_values import serialize_variable_values class Client: @@ -289,18 +290,79 @@ def __init__(self, client: Client): """:param client: the :class:`client ` used""" self.client = client - def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: + def _execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> ExecutionResult: + """Execute the provided document AST synchronously using + the sync transport, returning an ExecutionResult object. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport execute method.""" # Validate document if self.client.schema: self.client.validate(document) - return self.transport.execute(document, *args, **kwargs) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + + return self.transport.execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) - def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> Dict: + """Execute the provided document AST synchronously using + the sync transport. + + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport execute method.""" # Validate and execute on the transport - result = self._execute(document, *args, **kwargs) + result = self._execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, + ) # Raise an error if an error is returned in the ExecutionResult object if result.errors: @@ -341,17 +403,52 @@ def __init__(self, client: Client): self.client = client async def _subscribe( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> AsyncGenerator[ExecutionResult, None]: + """Coroutine to subscribe asynchronously to the provided document AST + asynchronously using the async transport, + returning an async generator producing ExecutionResult objects. + + * Validate the query with the schema if provided. + * Serialize the variable_values if requested. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport subscribe method.""" # Validate document if self.client.schema: self.client.validate(document) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + # Subscribe to the transport inner_generator: AsyncGenerator[ ExecutionResult, None - ] = self.transport.subscribe(document, *args, **kwargs) + ] = self.transport.subscribe( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) # Keep a reference to the inner generator to allow the user to call aclose() # before a break if python version is too old (pypy3 py 3.6.1) @@ -364,15 +461,35 @@ async def _subscribe( await inner_generator.aclose() async def subscribe( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> AsyncGenerator[Dict, None]: """Coroutine to subscribe asynchronously to the provided document AST asynchronously using the async transport. + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + The extra arguments are passed to the transport subscribe method.""" inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( - document, *args, **kwargs + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, ) try: @@ -391,27 +508,85 @@ async def subscribe( await inner_generator.aclose() async def _execute( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> ExecutionResult: + """Coroutine to execute the provided document AST asynchronously using + the async transport, returning an ExecutionResult object. + + * Validate the query with the schema if provided. + * Serialize the variable_values if requested. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + + The extra arguments are passed to the transport execute method.""" # Validate document if self.client.schema: self.client.validate(document) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + # Execute the query with the transport with a timeout return await asyncio.wait_for( - self.transport.execute(document, *args, **kwargs), + self.transport.execute( + document, + variable_values=variable_values, + operation_name=operation_name, + *args, + **kwargs, + ), self.client.execute_timeout, ) - async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + async def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> Dict: """Coroutine to execute the provided document AST asynchronously using the async transport. + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + The extra arguments are passed to the transport execute method.""" # Validate and execute on the transport - result = await self._execute(document, *args, **kwargs) + result = await self._execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, + ) # Raise an error if an error is returned in the ExecutionResult object if result.errors: diff --git a/gql/variable_values.py b/gql/variable_values.py new file mode 100644 index 00000000..7db7091a --- /dev/null +++ b/gql/variable_values.py @@ -0,0 +1,117 @@ +from typing import Any, Dict, Optional + +from graphql import ( + DocumentNode, + GraphQLEnumType, + GraphQLError, + GraphQLInputObjectType, + GraphQLList, + GraphQLNonNull, + GraphQLScalarType, + GraphQLSchema, + GraphQLType, + GraphQLWrappingType, + OperationDefinitionNode, + type_from_ast, +) +from graphql.pyutils import inspect + + +def get_document_operation( + document: DocumentNode, operation_name: Optional[str] = None +) -> OperationDefinitionNode: + """Returns the operation which should be executed in the document. + + Raises a GraphQLError if a single operation cannot be retrieved. + """ + + operation: Optional[OperationDefinitionNode] = None + + for definition in document.definitions: + if isinstance(definition, OperationDefinitionNode): + if operation_name is None: + if operation: + raise GraphQLError( + "Must provide operation name" + " if query contains multiple operations." + ) + operation = definition + elif definition.name and definition.name.value == operation_name: + operation = definition + + if not operation: + if operation_name is not None: + raise GraphQLError(f"Unknown operation named '{operation_name}'.") + + # The following line should never happen normally as the document is + # already verified before calling this function. + raise GraphQLError("Must provide an operation.") # pragma: no cover + + return operation + + +def serialize_value(type_: GraphQLType, value: Any) -> Any: + """Given a GraphQL type and a Python value, return the serialized value. + + Can be used to serialize Enums and/or Custom Scalars in variable values. + """ + + if value is None: + if isinstance(type_, GraphQLNonNull): + # raise GraphQLError(f"Type {type_.of_type.name} Cannot be None.") + raise GraphQLError(f"Type {inspect(type_)} Cannot be None.") + else: + return None + + if isinstance(type_, GraphQLWrappingType): + inner_type = type_.of_type + + if isinstance(type_, GraphQLNonNull): + return serialize_value(inner_type, value) + + elif isinstance(type_, GraphQLList): + return [serialize_value(inner_type, v) for v in value] + + elif isinstance(type_, (GraphQLScalarType, GraphQLEnumType)): + return type_.serialize(value) + + elif isinstance(type_, GraphQLInputObjectType): + return { + field_name: serialize_value(field.type, value[field_name]) + for field_name, field in type_.fields.items() + } + + raise GraphQLError(f"Impossible to serialize value with type: {inspect(type_)}.") + + +def serialize_variable_values( + schema: GraphQLSchema, + document: DocumentNode, + variable_values: Dict[str, Any], + operation_name: Optional[str] = None, +) -> Dict[str, Any]: + """Given a GraphQL document and a schema, serialize the Dictionary of + variable values. + + Useful to serialize Enums and/or Custom Scalars in variable values + """ + + parsed_variable_values: Dict[str, Any] = {} + + # Find the operation in the document + operation = get_document_operation(document, operation_name=operation_name) + + # Serialize every variable value defined for the operation + for var_def_node in operation.variable_definitions: + var_name = var_def_node.variable.name.value + var_type = type_from_ast(schema, var_def_node.type) + + if var_name in variable_values: + + assert var_type is not None + + var_value = variable_values[var_name] + + parsed_variable_values[var_name] = serialize_value(var_type, var_value) + + return parsed_variable_values diff --git a/tests/custom_scalars/__init__.py b/tests/custom_scalars/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_custom_scalar_datetime.py new file mode 100644 index 00000000..22b3987a --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_datetime.py @@ -0,0 +1,174 @@ +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +import pytest +from graphql.error import GraphQLError +from graphql.language import ValueNode +from graphql.pyutils import inspect +from graphql.type import ( + GraphQLArgument, + GraphQLField, + GraphQLInputField, + GraphQLInputObjectType, + GraphQLInt, + GraphQLList, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +def serialize_datetime(value: Any) -> str: + if not isinstance(value, datetime): + raise GraphQLError("Cannot serialize datetime value: " + inspect(value)) + return value.isoformat() + + +def parse_datetime_value(value: Any) -> datetime: + + if isinstance(value, str): + try: + # Note: a more solid custom scalar should use dateutil.parser.isoparse + # Not using it here in the test to avoid adding another dependency + return datetime.fromisoformat(value) + except Exception: + raise GraphQLError("Cannot parse datetime value : " + inspect(value)) + + else: + raise GraphQLError("Cannot parseee datetime value: " + inspect(value)) + + +def parse_datetime_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> datetime: + ast_value = value_from_ast_untyped(value_node, variables) + if not isinstance(ast_value, str): + raise GraphQLError("Cannot parse literal datetime value: " + inspect(ast_value)) + + return parse_datetime_value(ast_value) + + +DatetimeScalar = GraphQLScalarType( + name="Datetime", + serialize=serialize_datetime, + parse_value=parse_datetime_value, + parse_literal=parse_datetime_literal, +) + + +def resolve_shift_days(root, _info, time, days): + return time + timedelta(days=days) + + +def resolve_latest(root, _info, times): + return max(times) + + +def resolve_seconds(root, _info, interval): + print(f"interval={interval!r}") + return (interval["end"] - interval["start"]).total_seconds() + + +IntervalInputType = GraphQLInputObjectType( + "IntervalInput", + fields={ + "start": GraphQLInputField(DatetimeScalar), + "end": GraphQLInputField(DatetimeScalar), + }, +) + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "shiftDays": GraphQLField( + DatetimeScalar, + args={ + "time": GraphQLArgument(DatetimeScalar), + "days": GraphQLArgument(GraphQLInt), + }, + resolve=resolve_shift_days, + ), + "latest": GraphQLField( + DatetimeScalar, + args={"times": GraphQLArgument(GraphQLList(DatetimeScalar))}, + resolve=resolve_latest, + ), + "seconds": GraphQLField( + GraphQLInt, + args={"interval": GraphQLArgument(IntervalInputType)}, + resolve=resolve_seconds, + ), + }, +) + +schema = GraphQLSchema(query=queryType) + + +def test_shift_days(): + + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": now, + } + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + +def test_latest(): + + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") + + query = gql("query latest($times: [Datetime!]!) {latest(times: $times)}") + + variable_values = { + "times": [now, in_five_days], + } + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["latest"] == in_five_days.isoformat() + + +def test_seconds(): + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") + + query = gql( + "query seconds($interval: IntervalInput) {seconds(interval: $interval)}" + ) + + variable_values = {"interval": {"start": now, "end": in_five_days}} + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["seconds"] == 432000 diff --git a/tests/custom_scalars/test_custom_scalar_money.py b/tests/custom_scalars/test_custom_scalar_money.py new file mode 100644 index 00000000..6560a3eb --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_money.py @@ -0,0 +1,527 @@ +import asyncio +from typing import Any, Dict, NamedTuple, Optional + +import pytest +from graphql import graphql_sync +from graphql.error import GraphQLError +from graphql.language import ValueNode +from graphql.pyutils import inspect, is_finite +from graphql.type import ( + GraphQLArgument, + GraphQLField, + GraphQLFloat, + GraphQLInt, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql +from gql.variable_values import serialize_value + +from ..conftest import MS + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +class Money(NamedTuple): + amount: float + currency: str + + +def serialize_money(output_value: Any) -> Dict[str, Any]: + if not isinstance(output_value, Money): + raise GraphQLError("Cannot serialize money value: " + inspect(output_value)) + return output_value._asdict() + + +def parse_money_value(input_value: Any) -> Money: + """Using Money custom scalar from graphql-core tests except here the + input value is supposed to be a dict instead of a Money object.""" + + """ + if isinstance(input_value, Money): + return input_value + """ + + if isinstance(input_value, dict): + amount = input_value.get("amount", None) + currency = input_value.get("currency", None) + + if not is_finite(amount) or not isinstance(currency, str): + raise GraphQLError("Cannot parse money value dict: " + inspect(input_value)) + + return Money(float(amount), currency) + else: + raise GraphQLError("Cannot parse money value: " + inspect(input_value)) + + +def parse_money_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> Money: + money = value_from_ast_untyped(value_node, variables) + if variables is not None and ( + # variables are not set when checked with ValuesIOfCorrectTypeRule + not money + or not is_finite(money.get("amount")) + or not isinstance(money.get("currency"), str) + ): + raise GraphQLError("Cannot parse literal money value: " + inspect(money)) + return Money(**money) + + +MoneyScalar = GraphQLScalarType( + name="Money", + serialize=serialize_money, + parse_value=parse_money_value, + parse_literal=parse_money_literal, +) + + +def resolve_balance(root, _info): + return root + + +def resolve_to_euros(_root, _info, money): + amount = money.amount + currency = money.currency + if not amount or currency == "EUR": + return amount + if currency == "DM": + return amount * 0.5 + raise ValueError("Cannot convert to euros: " + inspect(money)) + + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "balance": GraphQLField(MoneyScalar, resolve=resolve_balance), + "toEuros": GraphQLField( + GraphQLFloat, + args={"money": GraphQLArgument(MoneyScalar)}, + resolve=resolve_to_euros, + ), + }, +) + + +def resolve_spent_money(spent_money, _info, **kwargs): + return spent_money + + +async def subscribe_spend_all(_root, _info, money): + while money.amount > 0: + money = Money(money.amount - 1, money.currency) + yield money + await asyncio.sleep(1 * MS) + + +subscriptionType = GraphQLObjectType( + "Subscription", + fields=lambda: { + "spend": GraphQLField( + MoneyScalar, + args={"money": GraphQLArgument(MoneyScalar)}, + subscribe=subscribe_spend_all, + resolve=resolve_spent_money, + ) + }, +) + +root_value = Money(42, "DM") + +schema = GraphQLSchema(query=queryType, subscription=subscriptionType,) + + +def test_custom_scalar_in_output(): + + client = Client(schema=schema) + + query = gql("{balance}") + + result = client.execute(query, root_value=root_value) + + print(result) + + assert result["balance"] == serialize_money(root_value) + + +def test_custom_scalar_in_input_query(): + + client = Client(schema=schema) + + query = gql('{toEuros(money: {amount: 10, currency: "DM"})}') + + result = client.execute(query, root_value=root_value) + + assert result["toEuros"] == 5 + + query = gql('{toEuros(money: {amount: 10, currency: "EUR"})}') + + result = client.execute(query, root_value=root_value) + + assert result["toEuros"] == 10 + + +def test_custom_scalar_in_input_variable_values(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = {"amount": 10, "currency": "DM"} + + variable_values = {"money": money_value} + + result = client.execute( + query, variable_values=variable_values, root_value=root_value + ) + + assert result["toEuros"] == 5 + + +def test_custom_scalar_in_input_variable_values_serialized(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ) + + assert result["toEuros"] == 5 + + +def test_custom_scalar_in_input_variable_values_serialized_with_operation_name(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + operation_name="myquery", + ) + + assert result["toEuros"] == 5 + + +def test_serialize_variable_values_exception_multiple_ops_without_operation_name(): + + client = Client(schema=schema) + + query = gql( + """ + query myconversion($money: Money) { + toEuros(money: $money) + } + + query mybalance { + balance + }""" + ) + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + with pytest.raises(GraphQLError) as exc_info: + client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ) + + exception = exc_info.value + + assert ( + str(exception) + == "Must provide operation name if query contains multiple operations." + ) + + +def test_serialize_variable_values_exception_operation_name_not_found(): + + client = Client(schema=schema) + + query = gql( + """ + query myconversion($money: Money) { + toEuros(money: $money) + } +""" + ) + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + with pytest.raises(GraphQLError) as exc_info: + client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + operation_name="invalid_operation_name", + ) + + exception = exc_info.value + + assert str(exception) == "Unknown operation named 'invalid_operation_name'." + + +def test_custom_scalar_subscribe_in_input_variable_values_serialized(): + + client = Client(schema=schema) + + query = gql("subscription spendAll($money: Money) {spend(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + expected_result = {"spend": {"amount": 10, "currency": "DM"}} + + for result in client.subscribe( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ): + print(f"result = {result!r}") + expected_result["spend"]["amount"] = expected_result["spend"]["amount"] - 1 + assert expected_result == result + + +async def make_money_backend(aiohttp_server): + from aiohttp import web + + async def handler(request): + data = await request.json() + source = data["query"] + + print(f"data keys = {data.keys()}") + try: + variables = data["variables"] + print(f"variables = {variables!r}") + except KeyError: + variables = None + + result = graphql_sync( + schema, source, variable_values=variables, root_value=root_value + ) + + print(f"backend result = {result!r}") + + return web.json_response( + { + "data": result.data, + "errors": [str(e) for e in result.errors] if result.errors else None, + } + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + return server + + +async def make_money_transport(aiohttp_server): + from gql.transport.aiohttp import AIOHTTPTransport + + server = await make_money_backend(aiohttp_server) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + return transport + + +async def make_sync_money_transport(aiohttp_server): + from gql.transport.requests import RequestsHTTPTransport + + server = await make_money_backend(aiohttp_server) + + url = server.make_url("/") + + transport = RequestsHTTPTransport(url=url, timeout=10) + + return (server, transport) + + +@pytest.mark.asyncio +async def test_custom_scalar_in_output_with_transport(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("{balance}") + + result = await session.execute(query) + + print(result) + + assert result["balance"] == serialize_money(root_value) + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_query_with_transport(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql('{toEuros(money: {amount: 10, currency: "DM"})}') + + result = await session.execute(query) + + assert result["toEuros"] == 5 + + query = gql('{toEuros(money: {amount: 10, currency: "EUR"})}') + + result = await session.execute(query) + + assert result["toEuros"] == 10 + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_variable_values_with_transport( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = {"amount": 10, "currency": "DM"} + # money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = await session.execute(query, variable_values=variable_values) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_variable_values_split_with_transport( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql( + """ +query myquery($amount: Float, $currency: String) { + toEuros(money: {amount: $amount, currency: $currency}) +}""" + ) + + variable_values = {"amount": 10, "currency": "DM"} + + result = await session.execute(query, variable_values=variable_values) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(schema=schema, transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +@pytest.mark.requests +async def test_custom_scalar_serialize_variables_sync_transport( + event_loop, aiohttp_server, run_sync_test +): + + server, transport = await make_sync_money_transport(aiohttp_server) + + def test_code(): + with Client(schema=schema, transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + await run_sync_test(event_loop, server, test_code) + + +def test_serialize_value_with_invalid_type(): + + with pytest.raises(GraphQLError) as exc_info: + serialize_value("Not a valid type", 50) + + exception = exc_info.value + + assert ( + str(exception) == "Impossible to serialize value with type: 'Not a valid type'." + ) + + +def test_serialize_value_with_non_null_type_null(): + + non_null_int = GraphQLNonNull(GraphQLInt) + + with pytest.raises(GraphQLError) as exc_info: + serialize_value(non_null_int, None) + + exception = exc_info.value + + assert str(exception) == "Type Int! Cannot be None." + + +def test_serialize_value_with_nullable_type(): + + nullable_int = GraphQLInt + + assert serialize_value(nullable_int, None) is None From 2b3db148fbe02d4b04c3f00f571c1d27bb7feb23 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 12 Nov 2021 15:24:51 +0100 Subject: [PATCH 2/5] Skip datetime tests on Python 3.6 --- tests/custom_scalars/test_custom_scalar_datetime.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_custom_scalar_datetime.py index 22b3987a..058b269e 100644 --- a/tests/custom_scalars/test_custom_scalar_datetime.py +++ b/tests/custom_scalars/test_custom_scalar_datetime.py @@ -20,9 +20,6 @@ from gql import Client, gql -# Marking all tests in this file with the aiohttp marker -pytestmark = pytest.mark.aiohttp - def serialize_datetime(value: Any) -> str: if not isinstance(value, datetime): @@ -110,6 +107,9 @@ def resolve_seconds(root, _info, interval): schema = GraphQLSchema(query=queryType) +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) def test_shift_days(): client = Client(schema=schema) @@ -131,6 +131,9 @@ def test_shift_days(): assert result["shiftDays"] == "2021-11-17T11:58:13.461161" +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) def test_latest(): client = Client(schema=schema) @@ -153,6 +156,9 @@ def test_latest(): assert result["latest"] == in_five_days.isoformat() +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) def test_seconds(): client = Client(schema=schema) From c251a2f9206b4a29b4a8404c08574b7c5c7f2993 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 12 Nov 2021 19:49:14 +0100 Subject: [PATCH 3/5] Add update_schema_scalars function This function allows us to update a schema from a file or from introspection with custom scalars implementations --- gql/utilities/__init__.py | 5 + gql/utilities/update_schema_scalars.py | 32 ++++++ .../test_custom_scalar_money.py | 108 ++++++++++++++++++ 3 files changed, 145 insertions(+) create mode 100644 gql/utilities/__init__.py create mode 100644 gql/utilities/update_schema_scalars.py diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py new file mode 100644 index 00000000..68b80156 --- /dev/null +++ b/gql/utilities/__init__.py @@ -0,0 +1,5 @@ +from .update_schema_scalars import update_schema_scalars + +__all__ = [ + "update_schema_scalars", +] diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py new file mode 100644 index 00000000..d5434c6b --- /dev/null +++ b/gql/utilities/update_schema_scalars.py @@ -0,0 +1,32 @@ +from typing import Iterable, List + +from graphql import GraphQLError, GraphQLScalarType, GraphQLSchema + + +def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]): + """Update the scalars in a schema with the scalars provided. + + This can be used to update the default Custom Scalar implementation + when the schema has been provided from a text file or from introspection. + """ + + if not isinstance(scalars, Iterable): + raise GraphQLError("Scalars argument should be a list of scalars.") + + for scalar in scalars: + if not isinstance(scalar, GraphQLScalarType): + raise GraphQLError("Scalars should be instances of GraphQLScalarType.") + + try: + schema_scalar = schema.type_map[scalar.name] + except KeyError: + raise GraphQLError(f"Scalar '{scalar.name}' not found in schema.") + + assert isinstance(schema_scalar, GraphQLScalarType) + + # Update the conversion methods + # Using setattr because mypy has a false positive + # https://github.com/python/mypy/issues/2427 + setattr(schema_scalar, "serialize", scalar.serialize) + setattr(schema_scalar, "parse_value", scalar.parse_value) + setattr(schema_scalar, "parse_literal", scalar.parse_literal) diff --git a/tests/custom_scalars/test_custom_scalar_money.py b/tests/custom_scalars/test_custom_scalar_money.py index 6560a3eb..238308a9 100644 --- a/tests/custom_scalars/test_custom_scalar_money.py +++ b/tests/custom_scalars/test_custom_scalar_money.py @@ -19,6 +19,8 @@ from graphql.utilities import value_from_ast_untyped from gql import Client, gql +from gql.transport.exceptions import TransportQueryError +from gql.utilities import update_schema_scalars from gql.variable_values import serialize_value from ..conftest import MS @@ -471,6 +473,112 @@ async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): assert result["toEuros"] == 5 +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables_no_schema(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + with pytest.raises(TransportQueryError): + await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables_schema_from_introspection( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + schema = session.client.schema + + # Updating the Money Scalar in the schema + # We cannot replace it because some other objects keep a reference + # to the existing Scalar + # cannot do: schema.type_map["Money"] = MoneyScalar + + money_scalar = schema.type_map["Money"] + + money_scalar.serialize = MoneyScalar.serialize + money_scalar.parse_value = MoneyScalar.parse_value + money_scalar.parse_literal = MoneyScalar.parse_literal + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_update_schema_scalars(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + # Update the schema MoneyScalar default implementation from + # introspection with our provided conversion methods + update_schema_scalars(session.client.schema, [MoneyScalar]) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +def test_update_schema_scalars_invalid_scalar(): + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, [int]) + + exception = exc_info.value + + assert str(exception) == "Scalars should be instances of GraphQLScalarType." + + +def test_update_schema_scalars_invalid_scalar_argument(): + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, MoneyScalar) + + exception = exc_info.value + + assert str(exception) == "Scalars argument should be a list of scalars." + + +def test_update_schema_scalars_scalar_not_found_in_schema(): + + NotFoundScalar = GraphQLScalarType(name="abcd",) + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, [MoneyScalar, NotFoundScalar]) + + exception = exc_info.value + + assert str(exception) == "Scalar 'abcd' not found in schema." + + @pytest.mark.asyncio @pytest.mark.requests async def test_custom_scalar_serialize_variables_sync_transport( From 104788500e8d65e64ca7474d88304368348425ae Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 13 Nov 2021 13:56:47 +0100 Subject: [PATCH 4/5] Adding documentation --- docs/modules/gql.rst | 1 + docs/modules/utilities.rst | 6 + docs/usage/custom_scalars.rst | 134 ++++++++++++++++++ docs/usage/index.rst | 1 + .../test_custom_scalar_datetime.py | 42 +++++- 5 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 docs/modules/utilities.rst create mode 100644 docs/usage/custom_scalars.rst diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index aac47c86..06a89a96 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -21,3 +21,4 @@ Sub-Packages client transport dsl + utilities diff --git a/docs/modules/utilities.rst b/docs/modules/utilities.rst new file mode 100644 index 00000000..47043b98 --- /dev/null +++ b/docs/modules/utilities.rst @@ -0,0 +1,6 @@ +gql.utilities +============= + +.. currentmodule:: gql.utilities + +.. automodule:: gql.utilities diff --git a/docs/usage/custom_scalars.rst b/docs/usage/custom_scalars.rst new file mode 100644 index 00000000..baee441e --- /dev/null +++ b/docs/usage/custom_scalars.rst @@ -0,0 +1,134 @@ +Custom Scalars +============== + +Scalar types represent primitive values at the leaves of a query. + +GraphQL provides a number of built-in scalars (Int, Float, String, Boolean and ID), but a GraphQL backend +can add additional custom scalars to its schema to better express values in their data model. + +For example, a schema can define the Datetime scalar to represent an ISO-8601 encoded date. + +The schema will then only contain: + +.. code-block:: python + + scalar Datetime + +When custom scalars are sent to the backend (as inputs) or from the backend (as outputs), +their values need to be serialized to be composed +of only built-in scalars, then at the destination the serialized values will be parsed again to +be able to represent the scalar in its local internal representation. + +Because this serialization/unserialization is dependent on the language used at both sides, it is not +described in the schema and needs to be defined independently at both sides (client, backend). + +A custom scalar value can have two different representations during its transport: + + - as a serialized value (usually as json): + + * in the results sent by the backend + * in the variables sent by the client alongside the query + + - as "literal" inside the query itself sent by the client + +To define a custom scalar, you need 3 methods: + + - a :code:`serialize` method used: + + * by the backend to serialize a custom scalar output in the result + * by the client to serialize a custom scalar input in the variables + + - a :code:`parse_value` method used: + + * by the backend to unserialize custom scalars inputs in the variables sent by the client + * by the client to unserialize custom scalars outputs from the results + + - a :code:`parse_literal` method used: + + * by the backend to unserialize custom scalars inputs inside the query itself + +To define a custom scalar object, we define a :code:`GraphQLScalarType` from graphql-core with +its name and the implementation of the above methods. + +Example for Datetime: + +.. code-block:: python + + from datetime import datetime + from typing import Any, Dict, Optional + + from graphql import GraphQLScalarType, ValueNode + from graphql.utilities import value_from_ast_untyped + + + def serialize_datetime(value: Any) -> str: + return value.isoformat() + + + def parse_datetime_value(value: Any) -> datetime: + return datetime.fromisoformat(value) + + + def parse_datetime_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None + ) -> datetime: + ast_value = value_from_ast_untyped(value_node, variables) + return parse_datetime_value(ast_value) + + + DatetimeScalar = GraphQLScalarType( + name="Datetime", + serialize=serialize_datetime, + parse_value=parse_datetime_value, + parse_literal=parse_datetime_literal, + ) + +Custom Scalars in inputs +------------------------ + +To provide custom scalars in input with gql, you can: + +- serialize the scalar yourself as "literal" in the query: + +.. code-block:: python + + query = gql( + """{ + shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) + }""" + ) + +- serialize the scalar yourself in a variable: + +.. code-block:: python + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": "2021-11-12T11:58:13.461161", + } + + result = client.execute(query, variable_values=variable_values) + +- add a custom scalar to the schema with :func:`update_schema_scalars ` + and execute the query with :code:`serialize_variables=True` + and gql will serialize the variable values from a Python object representation. + +For this, you need to provide a schema or set :code:`fetch_schema_from_transport=True` +in the client to request the schema from the backend. + +.. code-block:: python + + from gql.utilities import update_schema_scalars + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + update_schema_scalars(session.client.schema, [DatetimeScalar]) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = {"time": datetime.now()} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) diff --git a/docs/usage/index.rst b/docs/usage/index.rst index a7dd4d56..4a38093a 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -10,3 +10,4 @@ Usage variables headers file_upload + custom_scalars diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_custom_scalar_datetime.py index 058b269e..25c6bb31 100644 --- a/tests/custom_scalars/test_custom_scalar_datetime.py +++ b/tests/custom_scalars/test_custom_scalar_datetime.py @@ -38,7 +38,7 @@ def parse_datetime_value(value: Any) -> datetime: raise GraphQLError("Cannot parse datetime value : " + inspect(value)) else: - raise GraphQLError("Cannot parseee datetime value: " + inspect(value)) + raise GraphQLError("Cannot parse datetime value: " + inspect(value)) def parse_datetime_literal( @@ -131,6 +131,46 @@ def test_shift_days(): assert result["shiftDays"] == "2021-11-17T11:58:13.461161" +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_shift_days_serialized_manually_in_query(): + + client = Client(schema=schema) + + query = gql( + """{ + shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) + }""" + ) + + result = client.execute(query) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_shift_days_serialized_manually_in_variables(): + + client = Client(schema=schema) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": "2021-11-12T11:58:13.461161", + } + + result = client.execute(query, variable_values=variable_values) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + @pytest.mark.skipif( not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" ) From f484ae224d161240d3f689f3a5e055374492cbad Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 15 Nov 2021 17:33:59 +0100 Subject: [PATCH 5/5] dsl_serialize_to_literals --- gql/dsl.py | 102 +++++++- .../custom_scalars/test_custom_scalar_json.py | 241 ++++++++++++++++++ tests/starwars/test_dsl.py | 33 ++- 3 files changed, 368 insertions(+), 8 deletions(-) create mode 100644 tests/custom_scalars/test_custom_scalar_json.py diff --git a/gql/dsl.py b/gql/dsl.py index f3bd1fe2..1646d402 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,15 +1,22 @@ import logging +import re from abc import ABC, abstractmethod +from math import isfinite from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast from graphql import ( ArgumentNode, + BooleanValueNode, DocumentNode, + EnumValueNode, FieldNode, + FloatValueNode, FragmentDefinitionNode, FragmentSpreadNode, GraphQLArgument, + GraphQLError, GraphQLField, + GraphQLID, GraphQLInputObjectType, GraphQLInputType, GraphQLInterfaceType, @@ -20,6 +27,7 @@ GraphQLSchema, GraphQLWrappingType, InlineFragmentNode, + IntValueNode, ListTypeNode, ListValueNode, NamedTypeNode, @@ -31,25 +39,76 @@ OperationDefinitionNode, OperationType, SelectionSetNode, + StringValueNode, TypeNode, Undefined, ValueNode, VariableDefinitionNode, VariableNode, assert_named_type, + is_enum_type, is_input_object_type, + is_leaf_type, is_list_type, is_non_null_type, is_wrapping_type, print_ast, ) -from graphql.pyutils import FrozenList -from graphql.utilities import ast_from_value as default_ast_from_value +from graphql.pyutils import FrozenList, inspect from .utils import to_camel_case log = logging.getLogger(__name__) +_re_integer_string = re.compile("^-?(?:0|[1-9][0-9]*)$") + + +def ast_from_serialized_value_untyped(serialized: Any) -> Optional[ValueNode]: + """Given a serialized value, try our best to produce an AST. + + Anything ressembling an array (instance of Mapping) will be converted + to an ObjectFieldNode. + + Anything ressembling a list (instance of Iterable - except str) + will be converted to a ListNode. + + In some cases, a custom scalar can be serialized differently in the query + than in the variables. In that case, this function will not work.""" + + if serialized is None or serialized is Undefined: + return NullValueNode() + + if isinstance(serialized, Mapping): + field_items = ( + (key, ast_from_serialized_value_untyped(value)) + for key, value in serialized.items() + ) + field_nodes = ( + ObjectFieldNode(name=NameNode(value=field_name), value=field_value) + for field_name, field_value in field_items + if field_value + ) + return ObjectValueNode(fields=FrozenList(field_nodes)) + + if isinstance(serialized, Iterable) and not isinstance(serialized, str): + maybe_nodes = (ast_from_serialized_value_untyped(item) for item in serialized) + nodes = filter(None, maybe_nodes) + return ListValueNode(values=FrozenList(nodes)) + + if isinstance(serialized, bool): + return BooleanValueNode(value=serialized) + + if isinstance(serialized, int): + return IntValueNode(value=f"{serialized:d}") + + if isinstance(serialized, float) and isfinite(serialized): + return FloatValueNode(value=f"{serialized:g}") + + if isinstance(serialized, str): + return StringValueNode(value=serialized) + + raise TypeError(f"Cannot convert value to AST: {inspect(serialized)}.") + def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: """ @@ -60,15 +119,21 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: VariableNode when value is a DSLVariable Produce a GraphQL Value AST given a Python object. + + Raises a GraphQLError instead of returning None if we receive an Undefined + of if we receive a Null value for a Non-Null type. """ if isinstance(value, DSLVariable): return value.set_type(type_).ast_variable if is_non_null_type(type_): type_ = cast(GraphQLNonNull, type_) - ast_value = ast_from_value(value, type_.of_type) + inner_type = type_.of_type + ast_value = ast_from_value(value, inner_type) if isinstance(ast_value, NullValueNode): - return None + raise GraphQLError( + "Received Null value for a Non-Null type " f"{inspect(inner_type)}." + ) return ast_value # only explicit None, not Undefined or NaN @@ -77,7 +142,7 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: # undefined if value is Undefined: - return None + raise GraphQLError(f"Received Undefined value for type {inspect(type_)}.") # Convert Python list to GraphQL list. If the GraphQLType is a list, but the value # is not a list, convert the value using the list's item type. @@ -108,7 +173,32 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: ) return ObjectValueNode(fields=FrozenList(field_nodes)) - return default_ast_from_value(value, type_) + if is_leaf_type(type_): + # Since value is an internally represented value, it must be serialized to an + # externally represented value before converting into an AST. + serialized = type_.serialize(value) # type: ignore + + # if the serialized value is a string, then we should use the + # type to determine if it is an enum, an ID or a normal string + if isinstance(serialized, str): + # Enum types use Enum literals. + if is_enum_type(type_): + return EnumValueNode(value=serialized) + + # ID types can use Int literals. + if type_ is GraphQLID and _re_integer_string.match(serialized): + return IntValueNode(value=serialized) + + return StringValueNode(value=serialized) + + # Some custom scalars will serialize to dicts or lists + # Providing here a default conversion to AST using our best judgment + # until graphql-js issue #1817 is solved + # https://github.com/graphql/graphql-js/issues/1817 + return ast_from_serialized_value_untyped(serialized) + + # Not reachable. All possible input types have been considered. + raise TypeError(f"Unexpected input type: {inspect(type_)}.") def dsl_gql( diff --git a/tests/custom_scalars/test_custom_scalar_json.py b/tests/custom_scalars/test_custom_scalar_json.py new file mode 100644 index 00000000..80f99850 --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_json.py @@ -0,0 +1,241 @@ +from typing import Any, Dict, Optional + +import pytest +from graphql import ( + GraphQLArgument, + GraphQLError, + GraphQLField, + GraphQLFloat, + GraphQLInt, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.language import ValueNode +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql +from gql.dsl import DSLSchema + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +def serialize_json(value: Any) -> Dict[str, Any]: + return value + + +def parse_json_value(value: Any) -> Any: + return value + + +def parse_json_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> Any: + return value_from_ast_untyped(value_node, variables) + + +JsonScalar = GraphQLScalarType( + name="JSON", + serialize=serialize_json, + parse_value=parse_json_value, + parse_literal=parse_json_literal, +) + +root_value = { + "players": [ + { + "name": "John", + "level": 3, + "is_connected": True, + "score": 123.45, + "friends": ["Alex", "Alicia"], + }, + { + "name": "Alex", + "level": 4, + "is_connected": False, + "score": 1337.69, + "friends": None, + }, + ] +} + + +def resolve_players(root, _info): + return root["players"] + + +queryType = GraphQLObjectType( + name="Query", fields={"players": GraphQLField(JsonScalar, resolve=resolve_players)}, +) + + +def resolve_add_player(root, _info, player): + print(f"player = {player!r}") + root["players"].append(player) + return {"players": root["players"]} + + +mutationType = GraphQLObjectType( + name="Mutation", + fields={ + "addPlayer": GraphQLField( + JsonScalar, + args={"player": GraphQLArgument(GraphQLNonNull(JsonScalar))}, + resolve=resolve_add_player, + ) + }, +) + +schema = GraphQLSchema(query=queryType, mutation=mutationType) + + +def test_json_value_output(): + + client = Client(schema=schema) + + query = gql("query {players}") + + result = client.execute(query, root_value=root_value) + + print(result) + + assert result["players"] == serialize_json(root_value["players"]) + + +def test_json_value_input_in_ast(): + + client = Client(schema=schema) + + query = gql( + """ + mutation adding_player { + addPlayer(player: { + name: "Tom", + level: 1, + is_connected: True, + score: 0, + friends: [ + "John" + ] + }) +}""" + ) + + result = client.execute(query, root_value=root_value) + + print(result) + + players = result["addPlayer"]["players"] + + assert players == serialize_json(root_value["players"]) + assert players[-1]["name"] == "Tom" + + +def test_json_value_input_in_ast_with_variables(): + + print(f"{schema.type_map!r}") + client = Client(schema=schema) + + # Note: we need to manually add the built-in types which + # are not present in the schema + schema.type_map["Int"] = GraphQLInt + schema.type_map["Float"] = GraphQLFloat + + query = gql( + """ + mutation adding_player( + $name: String!, + $level: Int!, + $is_connected: Boolean, + $score: Float!, + $friends: [String!]!) { + + addPlayer(player: { + name: $name, + level: $level, + is_connected: $is_connected, + score: $score, + friends: $friends, + }) +}""" + ) + + variable_values = { + "name": "Barbara", + "level": 1, + "is_connected": False, + "score": 69, + "friends": ["Alex", "John"], + } + + result = client.execute( + query, variable_values=variable_values, root_value=root_value + ) + + print(result) + + players = result["addPlayer"]["players"] + + assert players == serialize_json(root_value["players"]) + assert players[-1]["name"] == "Barbara" + + +def test_json_value_input_in_dsl_argument(): + + ds = DSLSchema(schema) + + new_player = { + "name": "Tim", + "level": 0, + "is_connected": False, + "score": 5, + "friends": ["Lea"], + } + + query = ds.Mutation.addPlayer(player=new_player) + + print(str(query)) + + assert ( + str(query) + == """addPlayer( + player: {name: "Tim", level: 0, is_connected: false, score: 5, friends: ["Lea"]} +)""" + ) + + +def test_none_json_value_input_in_dsl_argument(): + + ds = DSLSchema(schema) + + with pytest.raises(GraphQLError) as exc_info: + ds.Mutation.addPlayer(player=None) + + assert "Received Null value for a Non-Null type JSON." in str(exc_info.value) + + +def test_json_value_input_with_none_list_in_dsl_argument(): + + ds = DSLSchema(schema) + + new_player = { + "name": "Bob", + "level": 9001, + "is_connected": True, + "score": 666.66, + "friends": None, + } + + query = ds.Mutation.addPlayer(player=new_player) + + print(str(query)) + + assert ( + str(query) + == """addPlayer( + player: {name: "Bob", level: 9001, is_connected: true, score: 666.66, friends: null} +)""" + ) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 93de6c03..d18bb37d 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -1,5 +1,7 @@ import pytest from graphql import ( + GraphQLError, + GraphQLID, GraphQLInt, GraphQLList, GraphQLNonNull, @@ -23,6 +25,7 @@ DSLSubscription, DSLVariable, DSLVariableDefinitions, + ast_from_serialized_value_untyped, ast_from_value, dsl_gql, ) @@ -54,12 +57,38 @@ def test_ast_from_value_with_none(): def test_ast_from_value_with_undefined(): - assert ast_from_value(Undefined, GraphQLInt) is None + with pytest.raises(GraphQLError) as exc_info: + ast_from_value(Undefined, GraphQLInt) + + assert "Received Undefined value for type Int." in str(exc_info.value) + + +def test_ast_from_value_with_graphqlid(): + + assert ast_from_value("12345", GraphQLID) == IntValueNode(value="12345") + + +def test_ast_from_value_with_invalid_type(): + with pytest.raises(TypeError) as exc_info: + ast_from_value(4, None) + + assert "Unexpected input type: None." in str(exc_info.value) def test_ast_from_value_with_non_null_type_and_none(): typ = GraphQLNonNull(GraphQLInt) - assert ast_from_value(None, typ) is None + + with pytest.raises(GraphQLError) as exc_info: + ast_from_value(None, typ) + + assert "Received Null value for a Non-Null type Int." in str(exc_info.value) + + +def test_ast_from_serialized_value_untyped_typeerror(): + with pytest.raises(TypeError) as exc_info: + ast_from_serialized_value_untyped(GraphQLInt) + + assert "Cannot convert value to AST: Int." in str(exc_info.value) def test_variable_to_ast_type_passing_wrapping_type():