@@ -297,7 +297,7 @@ def __getattr__(self, name: str) -> "DSLType":
297
297
298
298
assert isinstance (type_def , (GraphQLObjectType , GraphQLInterfaceType ))
299
299
300
- return DSLType (type_def )
300
+ return DSLType (type_def , self )
301
301
302
302
303
303
class DSLSelector (ABC ):
@@ -454,7 +454,27 @@ def is_valid_field(self, field: "DSLSelectable") -> bool:
454
454
return operation_name != "SUBSCRIPTION"
455
455
456
456
elif isinstance (field , DSLField ):
457
- return field .parent_type .name .upper () == operation_name
457
+
458
+ assert field .dsl_type is not None
459
+
460
+ schema = field .dsl_type ._dsl_schema ._schema
461
+
462
+ root_type = None
463
+
464
+ if operation_name == "QUERY" :
465
+ root_type = schema .query_type
466
+ elif operation_name == "MUTATION" :
467
+ root_type = schema .mutation_type
468
+ elif operation_name == "SUBSCRIPTION" :
469
+ root_type = schema .subscription_type
470
+
471
+ if root_type is None :
472
+ log .error (
473
+ f"Root type of type { operation_name } not found in the schema!"
474
+ )
475
+ return False
476
+
477
+ return field .parent_type .name == root_type .name
458
478
459
479
return False
460
480
@@ -585,16 +605,22 @@ class DSLType:
585
605
instances of :class:`DSLField`
586
606
"""
587
607
588
- def __init__ (self , graphql_type : Union [GraphQLObjectType , GraphQLInterfaceType ]):
608
+ def __init__ (
609
+ self ,
610
+ graphql_type : Union [GraphQLObjectType , GraphQLInterfaceType ],
611
+ dsl_schema : DSLSchema ,
612
+ ):
589
613
"""Initialize the DSLType with the GraphQL type.
590
614
591
615
.. warning::
592
616
Don't instantiate this class yourself.
593
617
Use attributes of the :class:`DSLSchema` instead.
594
618
595
619
:param graphql_type: the GraphQL type definition from the schema
620
+ :param dsl_schema: reference to the DSLSchema which created this type
596
621
"""
597
622
self ._type : Union [GraphQLObjectType , GraphQLInterfaceType ] = graphql_type
623
+ self ._dsl_schema = dsl_schema
598
624
log .debug (f"Creating { self !r} )" )
599
625
600
626
def __getattr__ (self , name : str ) -> "DSLField" :
@@ -611,7 +637,7 @@ def __getattr__(self, name: str) -> "DSLField":
611
637
f"Field { name } does not exist in type { self ._type .name } ."
612
638
)
613
639
614
- return DSLField (formatted_name , self ._type , field )
640
+ return DSLField (formatted_name , self ._type , field , self )
615
641
616
642
def __repr__ (self ) -> str :
617
643
return f"<{ self .__class__ .__name__ } { self ._type !r} >"
@@ -763,6 +789,7 @@ def __init__(
763
789
name : str ,
764
790
parent_type : Union [GraphQLObjectType , GraphQLInterfaceType ],
765
791
field : GraphQLField ,
792
+ dsl_type : Optional [DSLType ] = None ,
766
793
):
767
794
"""Initialize the DSLField.
768
795
@@ -774,10 +801,12 @@ def __init__(
774
801
:param parent_type: the GraphQL type definition from the schema of the
775
802
parent type of the field
776
803
:param field: the GraphQL field definition from the schema
804
+ :param dsl_type: reference of the DSLType instance which created this field
777
805
"""
778
806
self .parent_type = parent_type
779
807
self .field = field
780
808
self .ast_field = FieldNode (name = NameNode (value = name ), arguments = ())
809
+ self .dsl_type = dsl_type
781
810
782
811
log .debug (f"Creating { self !r} " )
783
812
0 commit comments