4
4
from graphql .type .definition import GraphQLObjectType , GraphQLField , GraphQLScalarType
5
5
6
6
7
- class TypeAdaptor (object ):
7
+ class TypeAdapter (object ):
8
8
"""Substitute custom scalars in a GQL response with their decoded counterparts.
9
9
10
10
GQL custom scalar types are defined on the GQL schema and are used to represent
@@ -16,20 +16,20 @@ class TypeAdaptor(object):
16
16
the `_traverse()` function).
17
17
18
18
Each time we find a field which is a custom scalar (it's type name appears
19
- as a key in self.custom_scalars ), we replace the value of that field with the
19
+ as a key in self.custom_types ), we replace the value of that field with the
20
20
decoded value. All of this logic happens in `_substitute()`.
21
21
22
22
Public Interface:
23
23
apply(): pass in a GQL response to replace all instances of custom
24
24
scalar strings with their deserialized representation."""
25
25
26
- def __init__ (self , schema : GraphQLSchema , custom_scalars : Dict [str , Any ] = {}) -> None :
26
+ def __init__ (self , schema : GraphQLSchema , custom_types : Dict [str , Any ] = {}) -> None :
27
27
""" schema: a graphQL schema in the GraphQLSchema format
28
- custom_scalars : a Dict[str, Any],
28
+ custom_types : a Dict[str, Any],
29
29
where str is the name of the custom scalar type, and
30
- Any is a class which has a `parse_value()` function"""
30
+ Any is a class which has a `parse_value(str )` function"""
31
31
self .schema = schema
32
- self .custom_scalars = custom_scalars
32
+ self .custom_types = custom_types
33
33
34
34
def _follow_type_chain (self , node : Any ) -> Any :
35
35
""" Get the type of the schema node in question.
@@ -61,46 +61,49 @@ def _lookup_scalar_type(self, keys: List[str]) -> Optional[str]:
61
61
If keys (e.g. ['film', 'release_date']) points to a scalar type, then
62
62
this function returns the name of that type. (e.g. 'DateTime')
63
63
64
- If it is not a scalar type (e..g a GraphQLObject or list ), then this
64
+ If it is not a scalar type (e..g a GraphQLObject), then this
65
65
function returns None.
66
66
67
67
`keys` is a breadcrumb trail telling us where to look in the GraphQL schema.
68
68
By default the root level is `schema.query`, if that fails, then we check
69
69
`schema.mutation`."""
70
70
71
- def iterate (node : Any , lookup : List [str ]):
72
- lookup = lookup .copy ()
71
+ def traverse_schema (node : Any , lookup : List [str ]):
73
72
if not lookup :
74
73
return self ._get_scalar_type_name (node )
75
74
76
75
final_node = self ._follow_type_chain (node )
77
- return iterate (final_node .fields [lookup .pop (0 )], lookup )
76
+ return traverse_schema (final_node .fields [lookup [0 ]], lookup [1 :])
77
+
78
+ if keys [0 ] in self .schema .get_query_type ().fields :
79
+ schema_root = self .schema .get_query_type ()
80
+ elif keys [0 ] in self .schema .get_mutation_type ().fields :
81
+ schema_root = self .schema .get_mutation_type ()
82
+ else :
83
+ return None
78
84
79
85
try :
80
- return iterate ( self . schema . get_query_type () , keys )
86
+ return traverse_schema ( schema_root , keys )
81
87
except (KeyError , AttributeError ):
82
- try :
83
- return iterate (self .schema .get_mutation_type (), keys )
84
- except (KeyError , AttributeError ):
85
- return None
88
+ return None
86
89
87
- def _substitute (self , keys : List [str ], value : Any ) -> Any :
90
+ def _get_decoded_scalar_type (self , keys : List [str ], value : Any ) -> Any :
88
91
"""Get the decoded value of the type identified by `keys`.
89
92
90
93
If the type is not a custom scalar, then return the original value.
91
94
92
95
If it is a custom scalar, return the deserialized value, as
93
96
output by `<CustomScalarType>.parse_value()`"""
94
97
scalar_type = self ._lookup_scalar_type (keys )
95
- if scalar_type and scalar_type in self .custom_scalars :
96
- return self .custom_scalars [scalar_type ].parse_value (value )
98
+ if scalar_type and scalar_type in self .custom_types :
99
+ return self .custom_types [scalar_type ].parse_value (value )
97
100
return value
98
101
99
- def _traverse (self , response : Dict [str , Any ], substitute : Callable ) -> Dict [str , Any ]:
102
+ def convert_scalars (self , response : Dict [str , Any ]) -> Dict [str , Any ]:
100
103
"""Recursively traverse the GQL response
101
104
102
- Recursively traverses the GQL response and calls the `substitute`
103
- function on all leaf nodes. The function is called with 2 arguments:
105
+ Recursively traverses the GQL response and calls _get_decoded_scalar_type()
106
+ for all leaf nodes. The function is called with 2 arguments:
104
107
keys: List[str] is a breadcrumb trail telling us where we are in the
105
108
response, and therefore, where to look in the GQL Schema.
106
109
value: Any is the value at that node in the response
@@ -109,15 +112,9 @@ def _traverse(self, response: Dict[str, Any], substitute: Callable) -> Dict[str,
109
112
modified."""
110
113
def iterate (node : Any , keys : List [str ] = []):
111
114
if isinstance (node , dict ):
112
- result = {}
113
- for _key , value in node .items ():
114
- result [_key ] = iterate (value , keys + [_key ])
115
- return result
115
+ return {_key : iterate (value , keys + [_key ]) for _key , value in node .items ()}
116
116
elif isinstance (node , list ):
117
117
return [(iterate (item , keys )) for item in node ]
118
118
else :
119
- return substitute (keys , node )
119
+ return self . _get_decoded_scalar_type (keys , node )
120
120
return iterate (response )
121
-
122
- def apply (self , response : Dict [str , Any ]) -> Dict [str , Any ]:
123
- return self ._traverse (response , self ._substitute )
0 commit comments