Skip to content

Fixed the bug where a nested GraphQLInputObjectType causing infinite get_arg_serializer recursive calls #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions gql/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def args(self, **kwargs):
arg = self.field.args.get(name)
if not arg:
raise KeyError(f"Argument {name} does not exist in {self.field}.")
arg_type_serializer = get_arg_serializer(arg.type)
arg_type_serializer = get_arg_serializer(arg.type, known_serializers=dict())
serialized_value = arg_type_serializer(value)
added_args.append(
ArgumentNode(name=NameNode(value=name), value=serialized_value)
Expand Down Expand Up @@ -151,21 +151,28 @@ def serialize_list(serializer, list_values):
return ListValueNode(values=FrozenList(serializer(v) for v in list_values))


def get_arg_serializer(arg_type):
def get_arg_serializer(arg_type, known_serializers):
if isinstance(arg_type, GraphQLNonNull):
return get_arg_serializer(arg_type.of_type)
return get_arg_serializer(arg_type.of_type, known_serializers)
if isinstance(arg_type, GraphQLInputField):
return get_arg_serializer(arg_type.type)
return get_arg_serializer(arg_type.type, known_serializers)
if isinstance(arg_type, GraphQLInputObjectType):
serializers = {k: get_arg_serializer(v) for k, v in arg_type.fields.items()}
return lambda value: ObjectValueNode(
if arg_type in known_serializers:
return known_serializers[arg_type]
known_serializers[arg_type] = None
serializers = {
k: get_arg_serializer(v, known_serializers)
for k, v in arg_type.fields.items()
}
known_serializers[arg_type] = lambda value: ObjectValueNode(
fields=FrozenList(
ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v))
for k, v in value.items()
)
)
return known_serializers[arg_type]
if isinstance(arg_type, GraphQLList):
inner_serializer = get_arg_serializer(arg_type.of_type)
inner_serializer = get_arg_serializer(arg_type.of_type, known_serializers)
return partial(serialize_list, inner_serializer)
if isinstance(arg_type, GraphQLEnumType):
return lambda value: EnumValueNode(value=arg_type.serialize(value))
Expand Down
Empty file added tests/nested_input/__init__.py
Empty file.
30 changes: 30 additions & 0 deletions tests/nested_input/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from graphql import (
GraphQLArgument,
GraphQLField,
GraphQLInputField,
GraphQLInputObjectType,
GraphQLInt,
GraphQLObjectType,
GraphQLSchema,
)

nestedInput = GraphQLInputObjectType(
"Nested",
description="The input object that has a field pointing to itself",
fields={"foo": GraphQLInputField(GraphQLInt, description="foo")},
)

nestedInput.fields["child"] = GraphQLInputField(nestedInput, description="child")

queryType = GraphQLObjectType(
"Query",
fields=lambda: {
"foo": GraphQLField(
args={"nested": GraphQLArgument(type_=nestedInput)},
resolve=lambda *args, **kwargs: 1,
type_=GraphQLInt,
),
},
)

NestedInputSchema = GraphQLSchema(query=queryType, types=[nestedInput],)
63 changes: 63 additions & 0 deletions tests/nested_input/test_nested_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from functools import partial

import pytest
from graphql import (
EnumValueNode,
GraphQLEnumType,
GraphQLInputField,
GraphQLInputObjectType,
GraphQLList,
GraphQLNonNull,
NameNode,
ObjectFieldNode,
ObjectValueNode,
ast_from_value,
)
from graphql.pyutils import FrozenList

import gql.dsl as dsl
from gql import Client
from gql.dsl import DSLSchema, serialize_list
from tests.nested_input.schema import NestedInputSchema

# back up the new func
new_get_arg_serializer = dsl.get_arg_serializer


def old_get_arg_serializer(arg_type, known_serializers=None):
if isinstance(arg_type, GraphQLNonNull):
return old_get_arg_serializer(arg_type.of_type)
if isinstance(arg_type, GraphQLInputField):
return old_get_arg_serializer(arg_type.type)
if isinstance(arg_type, GraphQLInputObjectType):
serializers = {k: old_get_arg_serializer(v) for k, v in arg_type.fields.items()}
return lambda value: ObjectValueNode(
fields=FrozenList(
ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v))
for k, v in value.items()
)
)
if isinstance(arg_type, GraphQLList):
inner_serializer = old_get_arg_serializer(arg_type.of_type)
return partial(serialize_list, inner_serializer)
if isinstance(arg_type, GraphQLEnumType):
return lambda value: EnumValueNode(value=arg_type.serialize(value))
return lambda value: ast_from_value(arg_type.serialize(value), arg_type)


@pytest.fixture
def ds():
client = Client(schema=NestedInputSchema)
ds = DSLSchema(client)
return ds


def test_nested_input_with_old_get_arg_serializer(ds):
dsl.get_arg_serializer = old_get_arg_serializer
with pytest.raises(RecursionError, match="maximum recursion depth exceeded"):
ds.query(ds.Query.foo.args(nested={"foo": 1}))


def test_nested_input_with_new_get_arg_serializer(ds):
dsl.get_arg_serializer = new_get_arg_serializer
assert ds.query(ds.Query.foo.args(nested={"foo": 1})) == {"foo": 1}