diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index fd485274..1c2c1c82 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -206,6 +206,35 @@ will generate a query equivalent to:: } } +Variable arguments with a default value +""""""""""""""""""""""""""""""""""""""" + +If you want to provide a **default value** for your variable, you can use +the :code:`default` method on a variable. + +The following code: + +.. code-block:: python + + var = DSLVariableDefinitions() + op = DSLMutation( + ds.Mutation.createReview.args( + review=var.review.default({"stars": 5, "commentary": "Wow!"}), + episode=var.episode, + ).select(ds.Review.stars, ds.Review.commentary) + ) + op.variable_definitions = var + query = dsl_gql(op) + +will generate a query equivalent to:: + + mutation ($review: ReviewInput = {stars: 5, commentary: "Wow!"}, $episode: Episode) { + createReview(review: $review, episode: $episode) { + stars + commentary + } + } + Subscriptions ^^^^^^^^^^^^^ diff --git a/gql/dsl.py b/gql/dsl.py index 63b71a07..7f09b928 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -18,6 +18,7 @@ FragmentDefinitionNode, FragmentSpreadNode, GraphQLArgument, + GraphQLEnumType, GraphQLError, GraphQLField, GraphQLID, @@ -28,9 +29,9 @@ GraphQLNamedType, GraphQLNonNull, GraphQLObjectType, + GraphQLScalarType, GraphQLSchema, GraphQLString, - GraphQLWrappingType, InlineFragmentNode, IntValueNode, ListTypeNode, @@ -50,7 +51,6 @@ ValueNode, VariableDefinitionNode, VariableNode, - assert_named_type, get_named_type, introspection_types, is_enum_type, @@ -134,7 +134,7 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: of if we receive a Null value for a Non-Null type. """ if isinstance(value, DSLVariable): - return value.set_type(type_).ast_variable + return value.set_type(type_).ast_variable_name if is_non_null_type(type_): type_ = cast(GraphQLNonNull, type_) @@ -529,26 +529,33 @@ class DSLVariable: def __init__(self, name: str): """:meta private:""" - self.type: Optional[TypeNode] = None self.name = name - self.ast_variable = VariableNode(name=NameNode(value=self.name)) + self.ast_variable_type: Optional[TypeNode] = None + self.ast_variable_name = VariableNode(name=NameNode(value=self.name)) + self.default_value = None + self.type: Optional[GraphQLInputType] = None - def to_ast_type( - self, type_: Union[GraphQLWrappingType, GraphQLNamedType] - ) -> TypeNode: + def to_ast_type(self, type_: GraphQLInputType) -> TypeNode: if is_wrapping_type(type_): if isinstance(type_, GraphQLList): return ListTypeNode(type=self.to_ast_type(type_.of_type)) + elif isinstance(type_, GraphQLNonNull): return NonNullTypeNode(type=self.to_ast_type(type_.of_type)) - type_ = assert_named_type(type_) + assert isinstance( + type_, (GraphQLScalarType, GraphQLEnumType, GraphQLInputObjectType) + ) + return NamedTypeNode(name=NameNode(value=type_.name)) - def set_type( - self, type_: Union[GraphQLWrappingType, GraphQLNamedType] - ) -> "DSLVariable": - self.type = self.to_ast_type(type_) + def set_type(self, type_: GraphQLInputType) -> "DSLVariable": + self.type = type_ + self.ast_variable_type = self.to_ast_type(type_) + return self + + def default(self, default_value: Any) -> "DSLVariable": + self.default_value = default_value return self @@ -581,9 +588,11 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: """ return tuple( VariableDefinitionNode( - type=var.type, - variable=var.ast_variable, - default_value=None, + type=var.ast_variable_type, + variable=var.ast_variable_name, + default_value=None + if var.default_value is None + else ast_from_value(var.default_value, var.type), ) for var in self.variables.values() if var.type is not None # only variables used diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 0b881806..d021e122 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -111,11 +111,11 @@ def test_ast_from_serialized_value_untyped_typeerror(): def test_variable_to_ast_type_passing_wrapping_type(): - wrapping_type = GraphQLNonNull(GraphQLList(StarWarsSchema.get_type("Droid"))) - variable = DSLVariable("droids") + wrapping_type = GraphQLNonNull(GraphQLList(StarWarsSchema.get_type("ReviewInput"))) + variable = DSLVariable("review_input") ast = variable.to_ast_type(wrapping_type) assert ast == NonNullTypeNode( - type=ListTypeNode(type=NamedTypeNode(name=NameNode(value="Droid"))) + type=ListTypeNode(type=NamedTypeNode(name=NameNode(value="ReviewInput"))) ) @@ -170,6 +170,50 @@ def test_add_variable_definitions(ds): ) +def test_add_variable_definitions_with_default_value_enum(ds): + var = DSLVariableDefinitions() + op = DSLMutation( + ds.Mutation.createReview.args( + review=var.review, episode=var.episode.default(4) + ).select(ds.Review.stars, ds.Review.commentary) + ) + op.variable_definitions = var + query = dsl_gql(op) + + assert ( + print_ast(query) + == """mutation ($review: ReviewInput, $episode: Episode = NEWHOPE) { + createReview(review: $review, episode: $episode) { + stars + commentary + } +}""" + ) + + +def test_add_variable_definitions_with_default_value_input_object(ds): + var = DSLVariableDefinitions() + op = DSLMutation( + ds.Mutation.createReview.args( + review=var.review.default({"stars": 5, "commentary": "Wow!"}), + episode=var.episode, + ).select(ds.Review.stars, ds.Review.commentary) + ) + op.variable_definitions = var + query = dsl_gql(op) + + assert ( + print_ast(query) + == """ +mutation ($review: ReviewInput = {stars: 5, commentary: "Wow!"}, $episode: Episode) { + createReview(review: $review, episode: $episode) { + stars + commentary + } +}""".strip() + ) + + def test_add_variable_definitions_in_input_object(ds): var = DSLVariableDefinitions() op = DSLMutation(