diff --git a/gql/dsl.py b/gql/dsl.py index 26b9f426..63b71a07 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -297,7 +297,7 @@ def __getattr__(self, name: str) -> "DSLType": assert isinstance(type_def, (GraphQLObjectType, GraphQLInterfaceType)) - return DSLType(type_def) + return DSLType(type_def, self) class DSLSelector(ABC): @@ -454,7 +454,27 @@ def is_valid_field(self, field: "DSLSelectable") -> bool: return operation_name != "SUBSCRIPTION" elif isinstance(field, DSLField): - return field.parent_type.name.upper() == operation_name + + assert field.dsl_type is not None + + schema = field.dsl_type._dsl_schema._schema + + root_type = None + + if operation_name == "QUERY": + root_type = schema.query_type + elif operation_name == "MUTATION": + root_type = schema.mutation_type + elif operation_name == "SUBSCRIPTION": + root_type = schema.subscription_type + + if root_type is None: + log.error( + f"Root type of type {operation_name} not found in the schema!" + ) + return False + + return field.parent_type.name == root_type.name return False @@ -585,7 +605,11 @@ class DSLType: instances of :class:`DSLField` """ - def __init__(self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType]): + def __init__( + self, + graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType], + dsl_schema: DSLSchema, + ): """Initialize the DSLType with the GraphQL type. .. warning:: @@ -593,8 +617,10 @@ def __init__(self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType]) Use attributes of the :class:`DSLSchema` instead. :param graphql_type: the GraphQL type definition from the schema + :param dsl_schema: reference to the DSLSchema which created this type """ self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type + self._dsl_schema = dsl_schema log.debug(f"Creating {self!r})") def __getattr__(self, name: str) -> "DSLField": @@ -611,7 +637,7 @@ def __getattr__(self, name: str) -> "DSLField": f"Field {name} does not exist in type {self._type.name}." ) - return DSLField(formatted_name, self._type, field) + return DSLField(formatted_name, self._type, field, self) def __repr__(self) -> str: return f"<{self.__class__.__name__} {self._type!r}>" @@ -763,6 +789,7 @@ def __init__( name: str, parent_type: Union[GraphQLObjectType, GraphQLInterfaceType], field: GraphQLField, + dsl_type: Optional[DSLType] = None, ): """Initialize the DSLField. @@ -774,10 +801,12 @@ def __init__( :param parent_type: the GraphQL type definition from the schema of the parent type of the field :param field: the GraphQL field definition from the schema + :param dsl_type: reference of the DSLType instance which created this field """ self.parent_type = parent_type self.field = field self.ast_field = FieldNode(name=NameNode(value=name), arguments=()) + self.dsl_type = dsl_type log.debug(f"Creating {self!r}") diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index c0f2b441..0b881806 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -728,6 +728,42 @@ def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): ) +def test_dsl_root_type_not_default(): + + from graphql import parse, build_ast_schema + + schema_str = """ +schema { + query: QueryNotDefault +} + +type QueryNotDefault { + version: String +} +""" + + type_def_ast = parse(schema_str) + schema = build_ast_schema(type_def_ast) + + ds = DSLSchema(schema) + + query = dsl_gql(DSLQuery(ds.QueryNotDefault.version)) + + expected_query = """ +{ + version +} +""" + assert print_ast(query) == expected_query.strip() + + with pytest.raises(GraphQLError) as excinfo: + DSLSubscription(ds.QueryNotDefault.version) + + assert ( + "Invalid field for : " + ) in str(excinfo.value) + + def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): with pytest.raises( TypeError, match="Operations should be instances of DSLExecutable "