Skip to content

Commit 67c8976

Browse files
author
xzhang2
committed
Fixed the bug where a nested GraphQLInputObjectType causing infinite recursive calls to get_arg_serializer.
1 parent ba18b5e commit 67c8976

File tree

4 files changed

+104
-7
lines changed

4 files changed

+104
-7
lines changed

gql/dsl.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def args(self, **kwargs):
105105
arg = self.field.args.get(name)
106106
if not arg:
107107
raise KeyError(f"Argument {name} does not exist in {self.field}.")
108-
arg_type_serializer = get_arg_serializer(arg.type)
108+
arg_type_serializer = get_arg_serializer(arg.type, known_serializers=dict())
109109
serialized_value = arg_type_serializer(value)
110110
added_args.append(
111111
ArgumentNode(name=NameNode(value=name), value=serialized_value)
@@ -151,21 +151,25 @@ def serialize_list(serializer, list_values):
151151
return ListValueNode(values=FrozenList(serializer(v) for v in list_values))
152152

153153

154-
def get_arg_serializer(arg_type):
154+
def get_arg_serializer(arg_type, known_serializers):
155155
if isinstance(arg_type, GraphQLNonNull):
156-
return get_arg_serializer(arg_type.of_type)
156+
return get_arg_serializer(arg_type.of_type, known_serializers)
157157
if isinstance(arg_type, GraphQLInputField):
158-
return get_arg_serializer(arg_type.type)
158+
return get_arg_serializer(arg_type.type, known_serializers)
159159
if isinstance(arg_type, GraphQLInputObjectType):
160-
serializers = {k: get_arg_serializer(v) for k, v in arg_type.fields.items()}
161-
return lambda value: ObjectValueNode(
160+
if arg_type in known_serializers:
161+
return known_serializers[arg_type]
162+
known_serializers[arg_type] = None
163+
serializers = {k: get_arg_serializer(v, known_serializers) for k, v in arg_type.fields.items()}
164+
known_serializers[arg_type] = lambda value: ObjectValueNode(
162165
fields=FrozenList(
163166
ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v))
164167
for k, v in value.items()
165168
)
166169
)
170+
return known_serializers[arg_type]
167171
if isinstance(arg_type, GraphQLList):
168-
inner_serializer = get_arg_serializer(arg_type.of_type)
172+
inner_serializer = get_arg_serializer(arg_type.of_type, known_serializers)
169173
return partial(serialize_list, inner_serializer)
170174
if isinstance(arg_type, GraphQLEnumType):
171175
return lambda value: EnumValueNode(value=arg_type.serialize(value))

tests/nested_input/__init__.py

Whitespace-only changes.

tests/nested_input/schema.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from graphql import (
2+
GraphQLArgument,
3+
GraphQLField,
4+
GraphQLInputField,
5+
GraphQLInputObjectType,
6+
GraphQLInt,
7+
GraphQLObjectType,
8+
GraphQLSchema,
9+
)
10+
11+
nestedInput = GraphQLInputObjectType(
12+
"Nested",
13+
description="The input object that has a field pointing to itself",
14+
fields={"foo": GraphQLInputField(GraphQLInt, description="foo")},
15+
)
16+
17+
nestedInput.fields["child"] = GraphQLInputField(nestedInput, description="child")
18+
19+
queryType = GraphQLObjectType(
20+
"Query",
21+
fields=lambda: {
22+
"foo": GraphQLField(
23+
args={"nested": GraphQLArgument(type_=nestedInput)},
24+
resolve=lambda *args, **kwargs: 1,
25+
type_=GraphQLInt,
26+
),
27+
},
28+
)
29+
30+
NestedInputSchema = GraphQLSchema(query=queryType, types=[nestedInput],)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from functools import partial
2+
3+
import pytest
4+
from graphql import (
5+
EnumValueNode,
6+
GraphQLEnumType,
7+
GraphQLInputField,
8+
GraphQLInputObjectType,
9+
GraphQLList,
10+
GraphQLNonNull,
11+
NameNode,
12+
ObjectFieldNode,
13+
ObjectValueNode,
14+
ast_from_value,
15+
)
16+
from graphql.pyutils import FrozenList
17+
18+
import gql.dsl as dsl
19+
from gql import Client
20+
from gql.dsl import DSLSchema, serialize_list
21+
from tests.nested_input.schema import NestedInputSchema
22+
23+
# back up the new func
24+
new_get_arg_serializer = dsl.get_arg_serializer
25+
26+
27+
def old_get_arg_serializer(arg_type, known_serializers=None):
28+
if isinstance(arg_type, GraphQLNonNull):
29+
return old_get_arg_serializer(arg_type.of_type)
30+
if isinstance(arg_type, GraphQLInputField):
31+
return old_get_arg_serializer(arg_type.type)
32+
if isinstance(arg_type, GraphQLInputObjectType):
33+
serializers = {k: old_get_arg_serializer(v) for k, v in arg_type.fields.items()}
34+
return lambda value: ObjectValueNode(
35+
fields=FrozenList(
36+
ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v))
37+
for k, v in value.items()
38+
)
39+
)
40+
if isinstance(arg_type, GraphQLList):
41+
inner_serializer = old_get_arg_serializer(arg_type.of_type)
42+
return partial(serialize_list, inner_serializer)
43+
if isinstance(arg_type, GraphQLEnumType):
44+
return lambda value: EnumValueNode(value=arg_type.serialize(value))
45+
return lambda value: ast_from_value(arg_type.serialize(value), arg_type)
46+
47+
48+
@pytest.fixture
49+
def ds():
50+
client = Client(schema=NestedInputSchema)
51+
ds = DSLSchema(client)
52+
return ds
53+
54+
55+
def test_nested_input_with_old_get_arg_serializer(ds):
56+
dsl.get_arg_serializer = old_get_arg_serializer
57+
with pytest.raises(RecursionError, match="maximum recursion depth exceeded"):
58+
ds.query(ds.Query.foo.args(nested={"foo": 1}))
59+
60+
61+
def test_nested_input_with_new_get_arg_serializer(ds):
62+
dsl.get_arg_serializer = new_get_arg_serializer
63+
assert ds.query(ds.Query.foo.args(nested={"foo": 1})) == {"foo": 1}

0 commit comments

Comments
 (0)