From 94abe44e1e3d33f453cdc29315ffb0a36d847b7a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 27 Mar 2022 23:02:20 +0200 Subject: [PATCH 1/4] Check root types names from the schema --- gql/dsl.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 6a2e0718..09bf30a8 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -294,7 +294,7 @@ def __getattr__(self, name: str) -> "DSLType": assert isinstance(type_def, (GraphQLObjectType, GraphQLInterfaceType)) - return DSLType(type_def) + return DSLType(type_def, self) class DSLSelector(ABC): @@ -445,7 +445,19 @@ def is_valid_field(self, field: "DSLSelectable") -> bool: return operation_name != "SUBSCRIPTION" elif isinstance(field, DSLField): - return field.parent_type.name.upper() == operation_name + + assert field.dsl_type is not None + + schema = field.dsl_type.dsl_schema._schema + + if operation_name == "QUERY": + root_type = schema.query_type + elif operation_name == "MUTATION": + root_type = schema.mutation_type + elif operation_name == "SUBSCRIPTION": + root_type = schema.subscription_type + + return field.parent_type.name == root_type.name return False @@ -574,7 +586,11 @@ class DSLType: instances of :class:`DSLField` """ - def __init__(self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType]): + def __init__( + self, + graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType], + dsl_schema: DSLSchema + ): """Initialize the DSLType with the GraphQL type. .. warning:: @@ -582,8 +598,10 @@ def __init__(self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType]) Use attributes of the :class:`DSLSchema` instead. :param graphql_type: the GraphQL type definition from the schema + :param dsl_schema: reference to the DSLSchema which created this type """ self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type + self.dsl_schema = dsl_schema log.debug(f"Creating {self!r})") def __getattr__(self, name: str) -> "DSLField": @@ -600,7 +618,7 @@ def __getattr__(self, name: str) -> "DSLField": f"Field {name} does not exist in type {self._type.name}." ) - return DSLField(formatted_name, self._type, field) + return DSLField(formatted_name, self._type, field, self) def __repr__(self) -> str: return f"<{self.__class__.__name__} {self._type!r}>" @@ -752,6 +770,7 @@ def __init__( name: str, parent_type: Union[GraphQLObjectType, GraphQLInterfaceType], field: GraphQLField, + dsl_type: Optional[DSLType] = None, ): """Initialize the DSLField. @@ -763,10 +782,12 @@ def __init__( :param parent_type: the GraphQL type definition from the schema of the parent type of the field :param field: the GraphQL field definition from the schema + :param dsl_type: reference of the DSLType instance which created this field """ self.parent_type = parent_type self.field = field self.ast_field = FieldNode(name=NameNode(value=name), arguments=()) + self.dsl_type = dsl_type log.debug(f"Creating {self!r}") From 06f3cd0ab31fb1634de1edf56e10d9d7cf6d3239 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 11 Apr 2022 17:08:48 +0200 Subject: [PATCH 2/4] Support case when root type is not present in schema --- gql/dsl.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gql/dsl.py b/gql/dsl.py index 41751ece..4b63cadc 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -456,6 +456,8 @@ def is_valid_field(self, field: "DSLSelectable") -> bool: schema = field.dsl_type.dsl_schema._schema + root_type = None + if operation_name == "QUERY": root_type = schema.query_type elif operation_name == "MUTATION": @@ -463,6 +465,10 @@ def is_valid_field(self, field: "DSLSelectable") -> bool: elif operation_name == "SUBSCRIPTION": root_type = schema.subscription_type + if root_type is None: + log.error(f"Root type of type {operation_name} not found in the schema!") + return False + return field.parent_type.name == root_type.name return False From 4e63d95bf58e427387ab155ff6c018f52362f2bf Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 11 Apr 2022 17:51:18 +0200 Subject: [PATCH 3/4] rename dsl_schema to _dsl_schema --- gql/dsl.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 4b63cadc..448114fd 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -454,7 +454,7 @@ def is_valid_field(self, field: "DSLSelectable") -> bool: assert field.dsl_type is not None - schema = field.dsl_type.dsl_schema._schema + schema = field.dsl_type._dsl_schema._schema root_type = None @@ -466,7 +466,9 @@ def is_valid_field(self, field: "DSLSelectable") -> bool: root_type = schema.subscription_type if root_type is None: - log.error(f"Root type of type {operation_name} not found in the schema!") + log.error( + f"Root type of type {operation_name} not found in the schema!" + ) return False return field.parent_type.name == root_type.name @@ -615,7 +617,7 @@ def __init__( :param dsl_schema: reference to the DSLSchema which created this type """ self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type - self.dsl_schema = dsl_schema + self._dsl_schema = dsl_schema log.debug(f"Creating {self!r})") def __getattr__(self, name: str) -> "DSLField": From 95b50d8f9ab51c83c1722556061835c7efa768b0 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 11 Apr 2022 17:51:38 +0200 Subject: [PATCH 4/4] Add tests --- tests/starwars/test_dsl.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 50f5449c..0b765f23 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -712,6 +712,42 @@ def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): ) +def test_dsl_root_type_not_default(): + + from graphql import parse, build_ast_schema + + schema_str = """ +schema { + query: QueryNotDefault +} + +type QueryNotDefault { + version: String +} +""" + + type_def_ast = parse(schema_str) + schema = build_ast_schema(type_def_ast) + + ds = DSLSchema(schema) + + query = dsl_gql(DSLQuery(ds.QueryNotDefault.version)) + + expected_query = """ +{ + version +} +""" + assert print_ast(query) == expected_query.strip() + + with pytest.raises(GraphQLError) as excinfo: + DSLSubscription(ds.QueryNotDefault.version) + + assert ( + "Invalid field for : " + ) in str(excinfo.value) + + def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): with pytest.raises( TypeError, match="Operations should be instances of DSLExecutable "