Skip to content

Commit d76c0fc

Browse files
authored
add discriminator property support (#214)
* add discriminator property support * simply the discriminator logic a bit * remove new category of tests for now * lint * lint * handle a case where there's multiple values mapped to same type
1 parent e7e8b14 commit d76c0fc

File tree

14 files changed

+610
-64
lines changed

14 files changed

+610
-64
lines changed

Diff for: end_to_end_tests/baseline_openapi_3.0.json

+6-3
Original file line numberDiff line numberDiff line change
@@ -2878,7 +2878,8 @@
28782878
"propertyName": "modelType",
28792879
"mapping": {
28802880
"type1": "#/components/schemas/ADiscriminatedUnionType1",
2881-
"type2": "#/components/schemas/ADiscriminatedUnionType2"
2881+
"type2": "#/components/schemas/ADiscriminatedUnionType2",
2882+
"type2-another-value": "#/components/schemas/ADiscriminatedUnionType2"
28822883
}
28832884
},
28842885
"oneOf": [
@@ -2896,15 +2897,17 @@
28962897
"modelType": {
28972898
"type": "string"
28982899
}
2899-
}
2900+
},
2901+
"required": ["modelType"]
29002902
},
29012903
"ADiscriminatedUnionType2": {
29022904
"type": "object",
29032905
"properties": {
29042906
"modelType": {
29052907
"type": "string"
29062908
}
2907-
}
2909+
},
2910+
"required": ["modelType"]
29082911
}
29092912
},
29102913
"parameters": {

Diff for: end_to_end_tests/baseline_openapi_3.1.yaml

+6-3
Original file line numberDiff line numberDiff line change
@@ -2872,7 +2872,8 @@ info:
28722872
"propertyName": "modelType",
28732873
"mapping": {
28742874
"type1": "#/components/schemas/ADiscriminatedUnionType1",
2875-
"type2": "#/components/schemas/ADiscriminatedUnionType2"
2875+
"type2": "#/components/schemas/ADiscriminatedUnionType2",
2876+
"type2-another-value": "#/components/schemas/ADiscriminatedUnionType2"
28762877
}
28772878
},
28782879
"oneOf": [
@@ -2890,15 +2891,17 @@ info:
28902891
"modelType": {
28912892
"type": "string"
28922893
}
2893-
}
2894+
},
2895+
"required": ["modelType"]
28942896
},
28952897
"ADiscriminatedUnionType2": {
28962898
"type": "object",
28972899
"properties": {
28982900
"modelType": {
28992901
"type": "string"
29002902
}
2901-
}
2903+
},
2904+
"required": ["modelType"]
29022905
}
29032906
}
29042907
"parameters": {

Diff for: end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
1-
from typing import Any, Dict, List, Type, TypeVar, Union
1+
from typing import Any, Dict, List, Type, TypeVar
22

33
from attrs import define as _attrs_define
44
from attrs import field as _attrs_field
55

6-
from ..types import UNSET, Unset
7-
86
T = TypeVar("T", bound="ADiscriminatedUnionType1")
97

108

119
@_attrs_define
1210
class ADiscriminatedUnionType1:
1311
"""
1412
Attributes:
15-
model_type (Union[Unset, str]):
13+
model_type (str):
1614
"""
1715

18-
model_type: Union[Unset, str] = UNSET
16+
model_type: str
1917
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
2018

2119
def to_dict(self) -> Dict[str, Any]:
2220
model_type = self.model_type
2321

2422
field_dict: Dict[str, Any] = {}
2523
field_dict.update(self.additional_properties)
26-
field_dict.update({})
27-
if model_type is not UNSET:
28-
field_dict["modelType"] = model_type
24+
field_dict.update(
25+
{
26+
"modelType": model_type,
27+
}
28+
)
2929

3030
return field_dict
3131

3232
@classmethod
3333
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
3434
d = src_dict.copy()
35-
model_type = d.pop("modelType", UNSET)
35+
model_type = d.pop("modelType")
3636

3737
a_discriminated_union_type_1 = cls(
3838
model_type=model_type,

Diff for: end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
1-
from typing import Any, Dict, List, Type, TypeVar, Union
1+
from typing import Any, Dict, List, Type, TypeVar
22

33
from attrs import define as _attrs_define
44
from attrs import field as _attrs_field
55

6-
from ..types import UNSET, Unset
7-
86
T = TypeVar("T", bound="ADiscriminatedUnionType2")
97

108

119
@_attrs_define
1210
class ADiscriminatedUnionType2:
1311
"""
1412
Attributes:
15-
model_type (Union[Unset, str]):
13+
model_type (str):
1614
"""
1715

18-
model_type: Union[Unset, str] = UNSET
16+
model_type: str
1917
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
2018

2119
def to_dict(self) -> Dict[str, Any]:
2220
model_type = self.model_type
2321

2422
field_dict: Dict[str, Any] = {}
2523
field_dict.update(self.additional_properties)
26-
field_dict.update({})
27-
if model_type is not UNSET:
28-
field_dict["modelType"] = model_type
24+
field_dict.update(
25+
{
26+
"modelType": model_type,
27+
}
28+
)
2929

3030
return field_dict
3131

3232
@classmethod
3333
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
3434
d = src_dict.copy()
35-
model_type = d.pop("modelType", UNSET)
35+
model_type = d.pop("modelType")
3636

3737
a_discriminated_union_type_2 = cls(
3838
model_type=model_type,

Diff for: end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py

+36-17
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,42 @@ def _parse_discriminated_union(
5959
return data
6060
if isinstance(data, Unset):
6161
return data
62-
try:
63-
if not isinstance(data, dict):
64-
raise TypeError()
65-
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)
66-
67-
return componentsschemas_a_discriminated_union_type_0
68-
except: # noqa: E722
69-
pass
70-
try:
71-
if not isinstance(data, dict):
72-
raise TypeError()
73-
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)
74-
75-
return componentsschemas_a_discriminated_union_type_1
76-
except: # noqa: E722
77-
pass
78-
return cast(Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], data)
62+
if not isinstance(data, dict):
63+
raise TypeError()
64+
if "modelType" in data:
65+
_discriminator_value = data["modelType"]
66+
67+
def _parse_1(data: object) -> ADiscriminatedUnionType1:
68+
if not isinstance(data, dict):
69+
raise TypeError()
70+
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)
71+
72+
return componentsschemas_a_discriminated_union_type_0
73+
74+
def _parse_2(data: object) -> ADiscriminatedUnionType2:
75+
if not isinstance(data, dict):
76+
raise TypeError()
77+
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)
78+
79+
return componentsschemas_a_discriminated_union_type_1
80+
81+
def _parse_3(data: object) -> ADiscriminatedUnionType2:
82+
if not isinstance(data, dict):
83+
raise TypeError()
84+
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)
85+
86+
return componentsschemas_a_discriminated_union_type_1
87+
88+
_discriminator_mapping = {
89+
"type1": _parse_1,
90+
"type2": _parse_2,
91+
"type2-another-value": _parse_3,
92+
}
93+
if _parse_fn := _discriminator_mapping.get(_discriminator_value):
94+
return cast(
95+
Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], _parse_fn(data)
96+
)
97+
raise TypeError("unrecognized value for property modelType")
7998

8099
discriminated_union = _parse_discriminated_union(d.pop("discriminated_union", UNSET))
81100

Diff for: openapi_python_client/parser/properties/model_property.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class ModelProperty(PropertyProtocol):
3232
relative_imports: set[str] | None
3333
lazy_imports: set[str] | None
3434
additional_properties: Property | None
35+
ref_path: ReferencePath | None = None
3536
_json_type_string: ClassVar[str] = "Dict[str, Any]"
3637

3738
template: ClassVar[str] = "model_property.py.jinja"

Diff for: openapi_python_client/parser/properties/protocol.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from openapi_python_client.parser.properties.schemas import ReferencePath
4+
35
__all__ = ["PropertyProtocol", "Value"]
46

57
from abc import abstractmethod
@@ -185,3 +187,6 @@ def is_base_type(self) -> bool:
185187
ListProperty.__name__,
186188
UnionProperty.__name__,
187189
}
190+
191+
def get_ref_path(self) -> ReferencePath | None:
192+
return self.ref_path if hasattr(self, "ref_path") else None

Diff for: openapi_python_client/parser/properties/schemas.py

+9
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,15 @@ def update_schemas_with_data(
142142
)
143143
return prop
144144

145+
# Save the original path (/components/schemas/X) in the property. This is important because:
146+
# 1. There are some contexts (such as a union with a discriminator) where we have a Property
147+
# instance and we want to know what its path is, instead of the other way round.
148+
# 2. Even though we did set prop.name to be the same as ref_path when we created it above,
149+
# whenever there's a $ref to this property, we end up making a copy of it and changing
150+
# the name. So we can't rely on prop.name always being the path.
151+
if hasattr(prop, "ref_path"):
152+
prop.ref_path = ref_path
153+
145154
schemas = evolve(schemas, classes_by_reference={ref_path: prop, **schemas.classes_by_reference})
146155
return schemas
147156

0 commit comments

Comments
 (0)