1
1
from __future__ import annotations
2
2
3
3
from itertools import chain
4
- from typing import Any , ClassVar , OrderedDict , cast
4
+ from typing import Any , ClassVar , Mapping , OrderedDict , cast
5
5
6
6
from attr import define , evolve
7
7
16
16
@define
17
17
class DiscriminatorDefinition :
18
18
"""Represents a discriminator that can optionally be specified for a union type.
19
-
19
+
20
20
Normally, a UnionProperty has either zero or one of these. However, a nested union
21
21
could have more than one, as we accumulate all the discriminators when we flatten
22
22
out the nested schemas. For example:
@@ -36,8 +36,9 @@ class DiscriminatorDefinition:
36
36
In this example there are four schemas and two discriminators. The deserializer
37
37
logic will check for the mammalType property first, then birdType.
38
38
"""
39
+
39
40
property_name : str
40
- value_to_model_map : dict [str , PropertyProtocol ]
41
+ value_to_model_map : Mapping [str , PropertyProtocol ]
41
42
# Every value in the map is really a ModelProperty, but this avoids circular imports
42
43
43
44
@@ -260,7 +261,7 @@ def _parse_discriminator(
260
261
# mapping:
261
262
# value_for_a: "#/components/schemas/ModelA"
262
263
# value_for_b: ModelB # equivalent to "#/components/schemas/ModelB"
263
- #
264
+ #
264
265
# For any type that isn't specified in the mapping (or if the whole mapping is omitted)
265
266
# the default lookup value for each schema is the same as the schema name. So this--
266
267
# mapping:
@@ -275,7 +276,7 @@ def _parse_discriminator(
275
276
def _get_model_name (model : ModelProperty ) -> str | None :
276
277
return get_reference_simple_name (model .ref_path ) if model .ref_path else None
277
278
278
- model_types_by_name : dict [str , PropertyProtocol ] = {}
279
+ model_types_by_name : dict [str , ModelProperty ] = {}
279
280
for model in subtypes :
280
281
# Note, model here can never be a UnionProperty, because we've already done
281
282
# flatten_union_properties() before this point.
@@ -290,7 +291,7 @@ def _get_model_name(model: ModelProperty) -> str | None:
290
291
)
291
292
model_types_by_name [name ] = model
292
293
293
- mapping : dict [str , PropertyProtocol ] = OrderedDict () # use ordered dict for test determinacy
294
+ mapping : dict [str , ModelProperty ] = OrderedDict () # use ordered dict for test determinacy
294
295
unspecified_models = list (model_types_by_name .values ())
295
296
if data .mapping :
296
297
for discriminator_value , model_ref in data .mapping .items ():
@@ -301,13 +302,13 @@ def _get_model_name(model: ModelProperty) -> str | None:
301
302
name = get_reference_simple_name (ref_path )
302
303
else :
303
304
name = model_ref
304
- model = model_types_by_name .get (name )
305
- if not model :
305
+ mapped_model = model_types_by_name .get (name )
306
+ if not mapped_model :
306
307
return PropertyError (
307
308
detail = f'Discriminator mapping referred to "{ name } " which is not one of the schema variants' ,
308
309
)
309
- mapping [discriminator_value ] = model
310
- unspecified_models .remove (model )
310
+ mapping [discriminator_value ] = mapped_model
311
+ unspecified_models .remove (mapped_model )
311
312
for model in unspecified_models :
312
313
if name := _get_model_name (model ):
313
314
mapping [name ] = model
0 commit comments