Skip to content

Commit ceb46b5

Browse files
author
xzhang2
committed
Fixed the bug where a nested GraphQLInputObjectType causing infinite recursive calls to get_arg_serializer.
1 parent ba18b5e commit ceb46b5

File tree

4 files changed

+105
-7
lines changed

4 files changed

+105
-7
lines changed

gql/dsl.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Iterable
22
from functools import partial
3+
from typing import Any, Callable, Dict
34

45
from graphql import (
56
ArgumentNode,
@@ -105,7 +106,7 @@ def args(self, **kwargs):
105106
arg = self.field.args.get(name)
106107
if not arg:
107108
raise KeyError(f"Argument {name} does not exist in {self.field}.")
108-
arg_type_serializer = get_arg_serializer(arg.type)
109+
arg_type_serializer = get_arg_serializer(arg.type, known_serializers=dict())
109110
serialized_value = arg_type_serializer(value)
110111
added_args.append(
111112
ArgumentNode(name=NameNode(value=name), value=serialized_value)
@@ -151,21 +152,25 @@ def serialize_list(serializer, list_values):
151152
return ListValueNode(values=FrozenList(serializer(v) for v in list_values))
152153

153154

154-
def get_arg_serializer(arg_type):
155+
def get_arg_serializer(arg_type, known_serializers):
155156
if isinstance(arg_type, GraphQLNonNull):
156-
return get_arg_serializer(arg_type.of_type)
157+
return get_arg_serializer(arg_type.of_type, known_serializers)
157158
if isinstance(arg_type, GraphQLInputField):
158-
return get_arg_serializer(arg_type.type)
159+
return get_arg_serializer(arg_type.type, known_serializers)
159160
if isinstance(arg_type, GraphQLInputObjectType):
160-
serializers = {k: get_arg_serializer(v) for k, v in arg_type.fields.items()}
161-
return lambda value: ObjectValueNode(
161+
if arg_type in known_serializers:
162+
return known_serializers[arg_type]
163+
known_serializers[arg_type] = None
164+
serializers = {k: get_arg_serializer(v, known_serializers) for k, v in arg_type.fields.items()}
165+
known_serializers[arg_type] = lambda value: ObjectValueNode(
162166
fields=FrozenList(
163167
ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v))
164168
for k, v in value.items()
165169
)
166170
)
171+
return known_serializers[arg_type]
167172
if isinstance(arg_type, GraphQLList):
168-
inner_serializer = get_arg_serializer(arg_type.of_type)
173+
inner_serializer = get_arg_serializer(arg_type.of_type, known_serializers)
169174
return partial(serialize_list, inner_serializer)
170175
if isinstance(arg_type, GraphQLEnumType):
171176
return lambda value: EnumValueNode(value=arg_type.serialize(value))

tests/nested_input/__init__.py

Whitespace-only changes.

tests/nested_input/schema.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from graphql import (
2+
GraphQLArgument,
3+
GraphQLField,
4+
GraphQLInputField,
5+
GraphQLInputObjectType,
6+
GraphQLInt,
7+
GraphQLObjectType,
8+
GraphQLSchema,
9+
)
10+
11+
nestedInput = GraphQLInputObjectType(
12+
"Nested",
13+
description="The input object that has a field pointing to itself",
14+
fields={"foo": GraphQLInputField(GraphQLInt, description="foo")},
15+
)
16+
17+
nestedInput.fields["child"] = GraphQLInputField(nestedInput, description="child")
18+
19+
queryType = GraphQLObjectType(
20+
"Query",
21+
fields=lambda: {
22+
"foo": GraphQLField(
23+
args={"nested": GraphQLArgument(type_=nestedInput)},
24+
resolve=lambda *args, **kwargs: 1,
25+
type_=GraphQLInt,
26+
),
27+
},
28+
)
29+
30+
NestedInputSchema = GraphQLSchema(query=queryType, types=[nestedInput],)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from functools import partial
2+
3+
import pytest
4+
from graphql import (
5+
EnumValueNode,
6+
GraphQLEnumType,
7+
GraphQLInputField,
8+
GraphQLInputObjectType,
9+
GraphQLList,
10+
GraphQLNonNull,
11+
NameNode,
12+
ObjectFieldNode,
13+
ObjectValueNode,
14+
ast_from_value,
15+
)
16+
from graphql.pyutils import FrozenList
17+
18+
import gql.dsl as dsl
19+
from gql import Client
20+
from gql.dsl import DSLSchema, serialize_list
21+
from tests.nested_input.schema import NestedInputSchema
22+
23+
# back up the new func
24+
new_get_arg_serializer = dsl.get_arg_serializer
25+
26+
27+
def old_get_arg_serializer(arg_type, known_serializers=None):
28+
if isinstance(arg_type, GraphQLNonNull):
29+
return old_get_arg_serializer(arg_type.of_type)
30+
if isinstance(arg_type, GraphQLInputField):
31+
return old_get_arg_serializer(arg_type.type)
32+
if isinstance(arg_type, GraphQLInputObjectType):
33+
serializers = {k: old_get_arg_serializer(v) for k, v in arg_type.fields.items()}
34+
return lambda value: ObjectValueNode(
35+
fields=FrozenList(
36+
ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v))
37+
for k, v in value.items()
38+
)
39+
)
40+
if isinstance(arg_type, GraphQLList):
41+
inner_serializer = old_get_arg_serializer(arg_type.of_type)
42+
return partial(serialize_list, inner_serializer)
43+
if isinstance(arg_type, GraphQLEnumType):
44+
return lambda value: EnumValueNode(value=arg_type.serialize(value))
45+
return lambda value: ast_from_value(arg_type.serialize(value), arg_type)
46+
47+
48+
@pytest.fixture
49+
def ds():
50+
client = Client(schema=NestedInputSchema)
51+
ds = DSLSchema(client)
52+
return ds
53+
54+
55+
def test_nested_input_with_old_get_arg_serializer(ds):
56+
dsl.get_arg_serializer = old_get_arg_serializer
57+
with pytest.raises(RecursionError, match="maximum recursion depth exceeded"):
58+
ds.query(ds.Query.foo.args(nested={"foo": 1}))
59+
60+
61+
def test_nested_input_with_new_get_arg_serializer(ds):
62+
dsl.get_arg_serializer = new_get_arg_serializer
63+
assert ds.query(ds.Query.foo.args(nested={"foo": 1})) == {"foo": 1}

0 commit comments

Comments
 (0)