Skip to content

Feat: Add support for custom scalar types [updated] #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1bbfbd7
Include support for GQL custom scalar types
Apr 20, 2018
9d80656
Fix failing tests
Apr 21, 2018
a13d166
Rename ResponseParser to TypeAdaptor
Apr 21, 2018
44becd0
Clean up docstrings in `type_adaptor` file
Apr 30, 2018
0457275
Bugfix: Empty headers should default to empty dict
Apr 30, 2018
5908011
Add new exceptions (GQLSyntaxError and GQLServerError)
Apr 30, 2018
0fd8937
Address CR changes
May 1, 2018
2988ed5
Remove type annotations from TypeAdapter for py2.7 compatibility
May 1, 2018
e0cc6b1
Merge branch 'master' into awais/waveussd-to-graphql
KingDarBoja Jun 14, 2020
5839be8
styles: apply black, flake8 and isort formatting
KingDarBoja Jun 14, 2020
7e43255
tests: use vcr to reproduce original schema
KingDarBoja Jun 22, 2020
53fd1cb
refactor: add better typings to type adapter
KingDarBoja Jun 22, 2020
6d83048
fix: correct check of type_adapter on client
KingDarBoja Jun 22, 2020
dd15658
Merge branch 'master' into custom-scalar-types
KingDarBoja Jun 29, 2020
5b19de5
Fix mypy issues and bring extra type hints
KingDarBoja Jun 29, 2020
598acd8
Fix type_adapter for subscriptions and async transports
leszekhanusz Jun 28, 2020
70f9ab2
Merge branch 'master' of https://github.com/graphql-python/gql into c…
leszekhanusz Jun 29, 2020
a98b272
Fix type hints - scalar values can be str or int or something else
leszekhanusz Jun 29, 2020
98475a8
Merge branch 'master' into custom-scalar-types
KingDarBoja Jul 4, 2020
33dc6e7
Merge branch 'master' into custom-scalar-types
KingDarBoja Jul 19, 2020
1ad5903
Merge branch 'master' into custom-scalar-types
KingDarBoja Aug 29, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .transport.exceptions import TransportQueryError
from .transport.local_schema import LocalSchemaTransport
from .transport.transport import Transport
from .type_adapter import TypeAdapter


class Client:
Expand All @@ -28,6 +29,7 @@ def __init__(
transport: Optional[Union[Transport, AsyncTransport]] = None,
fetch_schema_from_transport: bool = False,
execute_timeout: Optional[int] = 10,
custom_types: Optional[Dict[str, Any]] = None,
):
assert not (
type_def and introspection
Expand Down Expand Up @@ -77,10 +79,21 @@ def __init__(
# Enforced timeout of the execute function
self.execute_timeout = execute_timeout

# Fetch schema from transport directly if we are using a sync transport
if isinstance(transport, Transport) and fetch_schema_from_transport:
with self as session:
session.fetch_schema()

# Dictionary where the name of the custom scalar type is the key and the
# value is a class which has a `parse_value()` function
self.custom_types = custom_types

# Create a type_adapter instance directly here if we received the schema
# locally or from a sync transport
self.type_adapter = (
TypeAdapter(schema, custom_types) if custom_types and schema else None
)

def validate(self, document):
assert (
self.schema
Expand Down Expand Up @@ -233,7 +246,7 @@ def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult:

return self.transport.execute(document, *args, **kwargs)

def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
def execute(self, document: DocumentNode, *args, **kwargs) -> Dict[str, Any]:

# Validate and execute on the transport
result = self._execute(document, *args, **kwargs)
Expand All @@ -248,7 +261,10 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
result.data is not None
), "Transport returned an ExecutionResult without data or errors"

return result.data
if self.client.type_adapter:
return self.client.type_adapter.convert_scalars(result.data)
else:
return result.data

def fetch_schema(self) -> None:
execution_result = self.transport.execute(parse(get_introspection_query()))
Expand Down Expand Up @@ -280,6 +296,13 @@ async def fetch_and_validate(self, document: DocumentNode):
if self.client.fetch_schema_from_transport and not self.client.schema:
await self.fetch_schema()

# Once we have received the schema from the async transport,
# we can create a TypeAdapter instance if the user provided custom types
if self.client.custom_types and self.client.schema:
self.client.type_adapter = TypeAdapter(
self.client.schema, self.client.custom_types
)

# Validate document
if self.client.schema:
self.client.validate(document)
Expand Down Expand Up @@ -310,7 +333,7 @@ async def _subscribe(

async def subscribe(
self, document: DocumentNode, *args, **kwargs
) -> AsyncGenerator[Dict, None]:
) -> AsyncGenerator[Dict[str, Any], None]:

# Validate and subscribe on the transport
async for result in self._subscribe(document, *args, **kwargs):
Expand All @@ -322,7 +345,10 @@ async def subscribe(
)

elif result.data is not None:
yield result.data
if self.client.type_adapter:
yield self.client.type_adapter.convert_scalars(result.data)
else:
yield result.data

async def _execute(
self, document: DocumentNode, *args, **kwargs
Expand All @@ -337,7 +363,7 @@ async def _execute(
self.client.execute_timeout,
)

async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict[str, Any]:

# Validate and execute on the transport
result = await self._execute(document, *args, **kwargs)
Expand All @@ -352,7 +378,10 @@ async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
result.data is not None
), "Transport returned an ExecutionResult without data or errors"

return result.data
if self.client.type_adapter:
return self.client.type_adapter.convert_scalars(result.data)
else:
return result.data

async def fetch_schema(self) -> None:
execution_result = await self.transport.execute(
Expand Down
137 changes: 137 additions & 0 deletions gql/type_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from typing import Any, Dict, List, Optional, Union, cast

from graphql import GraphQLSchema
from graphql.type.definition import GraphQLField, GraphQLObjectType, GraphQLScalarType


class TypeAdapter:
"""Substitute custom scalars in a GQL response with their decoded counterparts.

GQL custom scalar types are defined on the GQL schema and are used to represent
fields which have special behaviour. To define custom scalar type, you need
the type name, and a class which has a class method called `parse_value()` -
this is the function which will be used to deserialize the custom scalar field.

We first need iterate over all the fields in the response (which is done in
the `_traverse()` function).

Each time we find a field which is a custom scalar (it's type name appears
as a key in self.custom_types), we replace the value of that field with the
decoded value. All of this logic happens in `_substitute()`.

Public Interface:
apply(): pass in a GQL response to replace all instances of custom
scalar strings with their deserialized representation."""

def __init__(
self, schema: GraphQLSchema, custom_types: Optional[Dict[str, Any]] = None
):
self.schema = schema
self.custom_types = custom_types or {}

@staticmethod
def _follow_type_chain(node):
""" Get the type of the schema node in question.

In the GraphQL schema, GraphQLFields have a "type" property. However, often
that dict has an "of_type" property itself. In order to get to the actual
type, we need to indefinitely follow the chain of "of_type" fields to get
to the last one, which is the one we care about."""
if isinstance(node, GraphQLObjectType):
return node

field_type = node.type
while hasattr(field_type, "of_type"):
field_type = field_type.of_type

return field_type

def _get_scalar_type_name(self, field):
"""Returns the name of the type if the type is a scalar type.
Returns `None` otherwise"""
node = self._follow_type_chain(field)
if isinstance(node, GraphQLScalarType):
return node.name
return None

def _lookup_scalar_type(self, keys: List[str]):
"""Search through the GQL schema and return the type identified by 'keys'.

If keys (e.g. ['film', 'release_date']) points to a scalar type, then
this function returns the name of that type. (e.g. 'DateTime')

If it is not a scalar type (e..g a GraphQLObject), then this
function returns `None`.

`keys` is a breadcrumb trail telling us where to look in the GraphQL schema.
By default the root level is `schema.query`, if that fails, then we check
`schema.mutation`."""

def traverse_schema(
node: Optional[Union[GraphQLObjectType, GraphQLField]], lookup
):
if not lookup:
return self._get_scalar_type_name(node)

final_node = self._follow_type_chain(node)
return traverse_schema(final_node.fields[lookup[0]], lookup[1:])

if self.schema.query_type and keys[0] in self.schema.query_type.fields:
schema_root = self.schema.query_type
elif self.schema.mutation_type and keys[0] in self.schema.mutation_type.fields:
schema_root = self.schema.mutation_type
elif (
self.schema.subscription_type
and keys[0] in self.schema.subscription_type.fields
):
schema_root = self.schema.subscription_type
else:
return None

try:
return traverse_schema(schema_root, keys)
except (KeyError, AttributeError):
return None

def _get_decoded_scalar_type(self, keys: List[str], value: Any) -> Any:
"""Get the decoded value of the type identified by `keys`.

If the type is not a custom scalar, then return the original value.

If it is a custom scalar, return the deserialized value, as
output by `<CustomScalarType>.parse_value()`"""

scalar_type = self._lookup_scalar_type(keys)
if scalar_type and scalar_type in self.custom_types:
return self.custom_types[scalar_type].parse_value(value)
return value

def convert_scalars(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively traverse the GQL response

Recursively traverses the GQL response and calls _get_decoded_scalar_type()
for all leaf nodes.

The function is called with 2 arguments:
keys: is a breadcrumb trail telling us where we are in the
response, and therefore, where to look in the GQL Schema.
value: is the value at that node in the response

Builds a new tree with the substituted values so old `response` is not
modified."""

def iterate(
node: Union[List, Dict, str], keys: List[str] = None
) -> Union[Dict[str, Any], List, Any]:
if keys is None:
keys = []
if isinstance(node, dict):
return {
_key: iterate(value, keys + [_key]) for _key, value in node.items()
}
elif isinstance(node, list):
return [(iterate(item, keys)) for item in node]
else:
return self._get_decoded_scalar_type(keys, node)

return cast(Dict, iterate(response))
44 changes: 44 additions & 0 deletions tests/test_async_client_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,47 @@ async def test_async_client_validation_fetch_schema_from_server_with_client_argu

with pytest.raises(graphql.error.GraphQLError):
await session.execute(query)


class ToLowercase:
@staticmethod
def parse_value(value: str):
return value.lower()


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [hero_server_answers], indirect=True)
async def test_async_client_validation_fetch_schema_from_server_with_custom_types(
event_loop, server
):

url = f"ws://{server.hostname}:{server.port}/graphql"

sample_transport = WebsocketsTransport(url=url)

custom_types = {"String": ToLowercase}

async with Client(
transport=sample_transport,
fetch_schema_from_transport=True,
custom_types=custom_types,
) as session:

query = gql(
"""
query HeroNameQuery {
hero {
name
}
}
"""
)

result = await session.execute(query)

print("Client received:", result)

# The expected hero name is now in lowercase
expected = {"hero": {"name": "r2-d2"}}

assert result == expected
56 changes: 56 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
)
from gql.transport.requests import RequestsHTTPTransport

from .test_type_adapter import Capitalize

query1_str = """
query getContinents {
continents {
Expand Down Expand Up @@ -201,3 +203,57 @@ def test_code():
sample_transport.execute(query)

await run_sync_test(event_loop, server, test_code)


partial_schema = """

type Continent {
code: ID!
name: String!
}

type Query {
continents: [Continent!]!
}

"""


@pytest.mark.asyncio
async def test_requests_query_with_custom_types(event_loop, aiohttp_server):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

def test_code():
sample_transport = RequestsHTTPTransport(url=url)

custom_types = {"String": Capitalize}

# Instanciate a client which will capitalize all the String scalars
with Client(
transport=sample_transport,
type_def=partial_schema,
custom_types=custom_types,
) as session:

query = gql(query1_str)

# Execute query synchronously
result = session.execute(query)

continents = result["continents"]

africa = continents[0]

assert africa["code"] == "AF"

# Check that the string is capitalized
assert africa["name"] == "AFRICA"

await run_sync_test(event_loop, server, test_code)
Loading