@@ -896,3 +896,84 @@ def default_resolver(root, info, ctx, **args):
896
896
schema .default_resolver = default_resolver
897
897
898
898
validate_schema (schema )
899
+
900
+
901
+ class TestInputTypeCycles :
902
+ def _schema (self , input_types ):
903
+ return Schema (
904
+ ObjectType (
905
+ "query" ,
906
+ [
907
+ Field (
908
+ "field" ,
909
+ Int ,
910
+ args = [Argument (t .name , t ) for t in input_types ],
911
+ )
912
+ ],
913
+ )
914
+ )
915
+
916
+ def test_no_cycle (self ):
917
+ A = InputObjectType ("A" , [InputField ("f" , Int )])
918
+ B = InputObjectType ("B" , [InputField ("f" , Int )])
919
+ C = InputObjectType ("C" , [InputField ("a" , lambda : NonNullType (A ))])
920
+ schema = self ._schema ([A , B , C ])
921
+ assert validate_schema (schema )
922
+
923
+ def test_simple_cycles (self ):
924
+ A = InputObjectType ("A" , [InputField ("b" , lambda : NonNullType (B ))])
925
+ B = InputObjectType ("B" , [InputField ("c" , lambda : NonNullType (C ))])
926
+ C = InputObjectType ("C" , [InputField ("a" , lambda : NonNullType (A ))])
927
+ schema = self ._schema ([A , B , C ])
928
+
929
+ with pytest .raises (SchemaError ) as exc_info :
930
+ validate_schema (schema )
931
+
932
+ assert set ([str (e ) for e in exc_info .value .errors ]) == set (
933
+ [
934
+ 'Non breakable input chain found: "B" > "C" > "A" > "B"' ,
935
+ 'Non breakable input chain found: "A" > "B" > "C" > "A"' ,
936
+ 'Non breakable input chain found: "C" > "A" > "B" > "C"' ,
937
+ ]
938
+ )
939
+
940
+ def test_multiple_cycles (self ):
941
+ A = InputObjectType (
942
+ "A" ,
943
+ [
944
+ InputField ("b" , lambda : NonNullType (B )),
945
+ InputField ("c" , lambda : NonNullType (C )),
946
+ ],
947
+ )
948
+ B = InputObjectType ("B" , [InputField ("a" , lambda : NonNullType (A ))])
949
+ C = InputObjectType ("C" , [InputField ("a" , lambda : NonNullType (A ))])
950
+ schema = self ._schema ([A , B , C ])
951
+
952
+ with pytest .raises (SchemaError ) as exc_info :
953
+ validate_schema (schema )
954
+
955
+ assert set ([str (e ) for e in exc_info .value .errors ]) == set (
956
+ [
957
+ 'Non breakable input chain found: "C" > "A" > "C"' ,
958
+ 'Non breakable input chain found: "A" > "C" > "A"' ,
959
+ 'Non breakable input chain found: "A" > "B" > "A"' ,
960
+ 'Non breakable input chain found: "B" > "A" > "B"' ,
961
+ ]
962
+ )
963
+
964
+ def test_simple_breakable_cycle (self ):
965
+ A = InputObjectType ("A" , [InputField ("b" , lambda : NonNullType (B ))])
966
+ B = InputObjectType ("B" , [InputField ("c" , lambda : C )])
967
+ C = InputObjectType ("C" , [InputField ("a" , lambda : NonNullType (A ))])
968
+ schema = self ._schema ([A , B , C ])
969
+ assert validate_schema (schema )
970
+
971
+ def test_list_breaks_cycle (self ):
972
+ A = InputObjectType ("A" , [InputField ("b" , lambda : NonNullType (B ))])
973
+ B = InputObjectType (
974
+ "B" ,
975
+ [InputField ("c" , lambda : NonNullType (ListType (NonNullType (C ))))],
976
+ )
977
+ C = InputObjectType ("C" , [InputField ("a" , lambda : NonNullType (A ))])
978
+ schema = self ._schema ([A , B , C ])
979
+ assert validate_schema (schema )
0 commit comments