Skip to content

Commit 0926ed6

Browse files
authored
Fix dsl root operation types custom names (#320)
1 parent ea96294 commit 0926ed6

File tree

2 files changed

+69
-4
lines changed

2 files changed

+69
-4
lines changed

gql/dsl.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def __getattr__(self, name: str) -> "DSLType":
297297

298298
assert isinstance(type_def, (GraphQLObjectType, GraphQLInterfaceType))
299299

300-
return DSLType(type_def)
300+
return DSLType(type_def, self)
301301

302302

303303
class DSLSelector(ABC):
@@ -454,7 +454,27 @@ def is_valid_field(self, field: "DSLSelectable") -> bool:
454454
return operation_name != "SUBSCRIPTION"
455455

456456
elif isinstance(field, DSLField):
457-
return field.parent_type.name.upper() == operation_name
457+
458+
assert field.dsl_type is not None
459+
460+
schema = field.dsl_type._dsl_schema._schema
461+
462+
root_type = None
463+
464+
if operation_name == "QUERY":
465+
root_type = schema.query_type
466+
elif operation_name == "MUTATION":
467+
root_type = schema.mutation_type
468+
elif operation_name == "SUBSCRIPTION":
469+
root_type = schema.subscription_type
470+
471+
if root_type is None:
472+
log.error(
473+
f"Root type of type {operation_name} not found in the schema!"
474+
)
475+
return False
476+
477+
return field.parent_type.name == root_type.name
458478

459479
return False
460480

@@ -585,16 +605,22 @@ class DSLType:
585605
instances of :class:`DSLField`
586606
"""
587607

588-
def __init__(self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType]):
608+
def __init__(
609+
self,
610+
graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType],
611+
dsl_schema: DSLSchema,
612+
):
589613
"""Initialize the DSLType with the GraphQL type.
590614
591615
.. warning::
592616
Don't instantiate this class yourself.
593617
Use attributes of the :class:`DSLSchema` instead.
594618
595619
:param graphql_type: the GraphQL type definition from the schema
620+
:param dsl_schema: reference to the DSLSchema which created this type
596621
"""
597622
self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type
623+
self._dsl_schema = dsl_schema
598624
log.debug(f"Creating {self!r})")
599625

600626
def __getattr__(self, name: str) -> "DSLField":
@@ -611,7 +637,7 @@ def __getattr__(self, name: str) -> "DSLField":
611637
f"Field {name} does not exist in type {self._type.name}."
612638
)
613639

614-
return DSLField(formatted_name, self._type, field)
640+
return DSLField(formatted_name, self._type, field, self)
615641

616642
def __repr__(self) -> str:
617643
return f"<{self.__class__.__name__} {self._type!r}>"
@@ -763,6 +789,7 @@ def __init__(
763789
name: str,
764790
parent_type: Union[GraphQLObjectType, GraphQLInterfaceType],
765791
field: GraphQLField,
792+
dsl_type: Optional[DSLType] = None,
766793
):
767794
"""Initialize the DSLField.
768795
@@ -774,10 +801,12 @@ def __init__(
774801
:param parent_type: the GraphQL type definition from the schema of the
775802
parent type of the field
776803
:param field: the GraphQL field definition from the schema
804+
:param dsl_type: reference of the DSLType instance which created this field
777805
"""
778806
self.parent_type = parent_type
779807
self.field = field
780808
self.ast_field = FieldNode(name=NameNode(value=name), arguments=())
809+
self.dsl_type = dsl_type
781810

782811
log.debug(f"Creating {self!r}")
783812

tests/starwars/test_dsl.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,42 @@ def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds):
728728
)
729729

730730

731+
def test_dsl_root_type_not_default():
732+
733+
from graphql import parse, build_ast_schema
734+
735+
schema_str = """
736+
schema {
737+
query: QueryNotDefault
738+
}
739+
740+
type QueryNotDefault {
741+
version: String
742+
}
743+
"""
744+
745+
type_def_ast = parse(schema_str)
746+
schema = build_ast_schema(type_def_ast)
747+
748+
ds = DSLSchema(schema)
749+
750+
query = dsl_gql(DSLQuery(ds.QueryNotDefault.version))
751+
752+
expected_query = """
753+
{
754+
version
755+
}
756+
"""
757+
assert print_ast(query) == expected_query.strip()
758+
759+
with pytest.raises(GraphQLError) as excinfo:
760+
DSLSubscription(ds.QueryNotDefault.version)
761+
762+
assert (
763+
"Invalid field for <DSLSubscription>: <DSLField QueryNotDefault::version>"
764+
) in str(excinfo.value)
765+
766+
731767
def test_dsl_gql_all_arguments_should_be_operations_or_fragments():
732768
with pytest.raises(
733769
TypeError, match="Operations should be instances of DSLExecutable "

0 commit comments

Comments
 (0)