Skip to content

Commit bc706f3

Browse files
committed
Validate unbreakable cycles of input types.
This is based on a spec proposal (see: graphql/graphql-spec#701) and may change in the future.
1 parent 321e126 commit bc706f3

File tree

2 files changed

+166
-1
lines changed

2 files changed

+166
-1
lines changed

src/py_gql/schema/validation.py

+85-1
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,20 @@
33
Schema validation utility
44
"""
55

6+
import collections
67
import re
78
from inspect import Parameter, signature
8-
from typing import TYPE_CHECKING, Any, Callable, List, Sequence, Set, Union
9+
from typing import (
10+
TYPE_CHECKING,
11+
Any,
12+
Callable,
13+
Dict,
14+
List,
15+
Sequence,
16+
Set,
17+
Tuple,
18+
Union,
19+
)
920

1021
from .._string_utils import quoted_options_list
1122
from ..exc import SchemaError, SchemaValidationError
@@ -18,6 +29,7 @@
1829
EnumValue,
1930
InputObjectType,
2031
InterfaceType,
32+
ListType,
2133
NamedType,
2234
NonNullType,
2335
ObjectType,
@@ -167,6 +179,7 @@ def __call__(self) -> None:
167179
self.add_error("%s is not a valid schema type" % type_)
168180

169181
self.validate_directives()
182+
self.validate_cyclic_input_types()
170183

171184
def validate_root_types(self) -> None:
172185
query = self.schema.query_type
@@ -507,3 +520,74 @@ def validate_input_fields(self, input_object: InputObjectType) -> None:
507520
)
508521

509522
fieldnames.add(field.name)
523+
524+
def validate_cyclic_input_types(self) -> None:
525+
"""
526+
Detect unbroken chains of input types.
527+
528+
Generally input types can refer to themselves as long as it is through a
529+
nullable type or a list, non nullable cycles are not supported.
530+
531+
This is currently (2020-10-31) `in the process of stabilising
532+
<https://github.com/graphql/graphql-spec/pull/701/>`_ and may change in
533+
the future.
534+
"""
535+
# TODO: Add link to spec / RFC in errors when stabilised.
536+
input_types = [
537+
t
538+
for t in self.schema.types.values()
539+
if isinstance(t, InputObjectType)
540+
]
541+
542+
direct_references = collections.defaultdict(set)
543+
544+
# Collect any non breakable reference to any input object type.
545+
for t in input_types:
546+
for f in t.fields:
547+
real_type = f.type
548+
549+
# Non null types are breakable by default, wrapped types are not.
550+
breakable = not isinstance(real_type, (ListType, NonNullType))
551+
while isinstance(real_type, (ListType, NonNullType)):
552+
# List can break the chain.
553+
if isinstance(real_type, ListType):
554+
breakable = True
555+
real_type = real_type.type
556+
557+
if (not breakable) and isinstance(real_type, InputObjectType):
558+
direct_references[t].add(real_type)
559+
560+
chains = [] # type: List[Tuple[str, Dict[str, List[str]]]]
561+
562+
def _search(outer, acc=None, path=None):
563+
acc, path = acc or set(), path or ()
564+
565+
for inner in direct_references[outer]:
566+
if inner.name in path:
567+
break
568+
569+
if (inner.name, path) in acc:
570+
break
571+
572+
acc.add((inner.name, path))
573+
_search(inner, acc, (*path, inner.name))
574+
575+
return acc
576+
577+
all_chains = [
578+
(t.name, _search(t)) for t in list(direct_references.keys())
579+
]
580+
581+
# TODO: This will contain multiple rotated versions of any given cycle.
582+
# This is fine for now, but would be nice to avoid duplicate data.
583+
for typename, chains in all_chains:
584+
for final, path in chains:
585+
if final == typename:
586+
self.add_error(
587+
"Non breakable input chain found: %s"
588+
% quoted_options_list(
589+
[typename, *path, typename],
590+
separator=" > ",
591+
final_separator=" > ",
592+
)
593+
)

tests/test_schema/test_validation.py

+81
Original file line numberDiff line numberDiff line change
@@ -896,3 +896,84 @@ def default_resolver(root, info, ctx, **args):
896896
schema.default_resolver = default_resolver
897897

898898
validate_schema(schema)
899+
900+
901+
class TestInputTypeCycles:
902+
def _schema(self, input_types):
903+
return Schema(
904+
ObjectType(
905+
"query",
906+
[
907+
Field(
908+
"field",
909+
Int,
910+
args=[Argument(t.name, t) for t in input_types],
911+
)
912+
],
913+
)
914+
)
915+
916+
def test_no_cycle(self):
917+
A = InputObjectType("A", [InputField("f", Int)])
918+
B = InputObjectType("B", [InputField("f", Int)])
919+
C = InputObjectType("C", [InputField("a", lambda: NonNullType(A))])
920+
schema = self._schema([A, B, C])
921+
assert validate_schema(schema)
922+
923+
def test_simple_cycles(self):
924+
A = InputObjectType("A", [InputField("b", lambda: NonNullType(B))])
925+
B = InputObjectType("B", [InputField("c", lambda: NonNullType(C))])
926+
C = InputObjectType("C", [InputField("a", lambda: NonNullType(A))])
927+
schema = self._schema([A, B, C])
928+
929+
with pytest.raises(SchemaError) as exc_info:
930+
validate_schema(schema)
931+
932+
assert set([str(e) for e in exc_info.value.errors]) == set(
933+
[
934+
'Non breakable input chain found: "B" > "C" > "A" > "B"',
935+
'Non breakable input chain found: "A" > "B" > "C" > "A"',
936+
'Non breakable input chain found: "C" > "A" > "B" > "C"',
937+
]
938+
)
939+
940+
def test_multiple_cycles(self):
941+
A = InputObjectType(
942+
"A",
943+
[
944+
InputField("b", lambda: NonNullType(B)),
945+
InputField("c", lambda: NonNullType(C)),
946+
],
947+
)
948+
B = InputObjectType("B", [InputField("a", lambda: NonNullType(A))])
949+
C = InputObjectType("C", [InputField("a", lambda: NonNullType(A))])
950+
schema = self._schema([A, B, C])
951+
952+
with pytest.raises(SchemaError) as exc_info:
953+
validate_schema(schema)
954+
955+
assert set([str(e) for e in exc_info.value.errors]) == set(
956+
[
957+
'Non breakable input chain found: "C" > "A" > "C"',
958+
'Non breakable input chain found: "A" > "C" > "A"',
959+
'Non breakable input chain found: "A" > "B" > "A"',
960+
'Non breakable input chain found: "B" > "A" > "B"',
961+
]
962+
)
963+
964+
def test_simple_breakable_cycle(self):
965+
A = InputObjectType("A", [InputField("b", lambda: NonNullType(B))])
966+
B = InputObjectType("B", [InputField("c", lambda: C)])
967+
C = InputObjectType("C", [InputField("a", lambda: NonNullType(A))])
968+
schema = self._schema([A, B, C])
969+
assert validate_schema(schema)
970+
971+
def test_list_breaks_cycle(self):
972+
A = InputObjectType("A", [InputField("b", lambda: NonNullType(B))])
973+
B = InputObjectType(
974+
"B",
975+
[InputField("c", lambda: NonNullType(ListType(NonNullType(C))))],
976+
)
977+
C = InputObjectType("C", [InputField("a", lambda: NonNullType(A))])
978+
schema = self._schema([A, B, C])
979+
assert validate_schema(schema)

0 commit comments

Comments
 (0)