Skip to content

correctly resolve references to a type that is itself just a single allOf reference #1103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
default: patch
---

# Correctly resolve references to a type that is itself just a single allOf reference

PR #1103 fixed issue #1091. Thanks @eli-bl!
44 changes: 44 additions & 0 deletions end_to_end_tests/baseline_openapi_3.0.json
Original file line number Diff line number Diff line change
@@ -1629,6 +1629,33 @@
}
}
}
},
"/models/allof": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"aliased": {
"$ref": "#/components/schemas/Aliased"
},
"extended": {
"$ref": "#/components/schemas/Extended"
},
"model": {
"$ref": "#/components/schemas/AModel"
}
}
}
}
}
}
}
}
}
},
"components": {
@@ -1647,6 +1674,23 @@
"an_required_field"
]
},
"Aliased":{
"allOf": [
{"$ref": "#/components/schemas/AModel"}
]
},
"Extended": {
"allOf": [
{"$ref": "#/components/schemas/Aliased"},
{"type": "object",
"properties": {
"fromExtended": {
"type": "string"
}
}
}
]
},
"AModel": {
"title": "AModel",
"required": [
64 changes: 48 additions & 16 deletions end_to_end_tests/baseline_openapi_3.1.yaml
Original file line number Diff line number Diff line change
@@ -1619,7 +1619,34 @@ info:
}
}
}
}
},
"/models/allof": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"aliased": {
"$ref": "#/components/schemas/Aliased"
},
"extended": {
"$ref": "#/components/schemas/Extended"
},
"model": {
"$ref": "#/components/schemas/AModel"
}
}
}
}
}
}
}
}
},
}
"components":
"schemas": {
@@ -1637,6 +1664,23 @@ info:
"an_required_field"
]
},
"Aliased": {
"allOf": [
{ "$ref": "#/components/schemas/AModel" }
]
},
"Extended": {
"allOf": [
{ "$ref": "#/components/schemas/Aliased" },
{ "type": "object",
"properties": {
"fromExtended": {
"type": "string"
}
}
}
]
},
"AModel": {
"title": "AModel",
"required": [
@@ -1667,11 +1711,7 @@ info:
"default": "overridden_default"
},
"an_optional_allof_enum": {
"allOf": [
{
"$ref": "#/components/schemas/AnAllOfEnum"
}
]
"$ref": "#/components/schemas/AnAllOfEnum",
},
"nested_list_of_enums": {
"title": "Nested List Of Enums",
@@ -1808,11 +1848,7 @@ info:
]
},
"model": {
"allOf": [
{
"$ref": "#/components/schemas/ModelWithUnionProperty"
}
]
"$ref": "#/components/schemas/ModelWithUnionProperty"
},
"nullable_model": {
"oneOf": [
@@ -1825,11 +1861,7 @@ info:
]
},
"not_required_model": {
"allOf": [
{
"$ref": "#/components/schemas/ModelWithUnionProperty"
}
]
"$ref": "#/components/schemas/ModelWithUnionProperty"
},
"not_required_nullable_model": {
"oneOf": [
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

import types

from . import get_common_parameters, post_common_parameters, reserved_parameters
from . import get_common_parameters, get_models_allof, post_common_parameters, reserved_parameters


class DefaultEndpoints:
@@ -17,3 +17,7 @@ def post_common_parameters(cls) -> types.ModuleType:
@classmethod
def reserved_parameters(cls) -> types.ModuleType:
return reserved_parameters

@classmethod
def get_models_allof(cls) -> types.ModuleType:
return get_models_allof
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from http import HTTPStatus
from typing import Any, Dict, Optional, Union

import httpx

from ... import errors
from ...client import AuthenticatedClient, Client
from ...models.get_models_allof_response_200 import GetModelsAllofResponse200
from ...types import Response


def _get_kwargs() -> Dict[str, Any]:
_kwargs: Dict[str, Any] = {
"method": "get",
"url": "/models/allof",
}

return _kwargs


def _parse_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[GetModelsAllofResponse200]:
if response.status_code == HTTPStatus.OK:
response_200 = GetModelsAllofResponse200.from_dict(response.json())

return response_200
if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
else:
return None


def _build_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[GetModelsAllofResponse200]:
return Response(
status_code=HTTPStatus(response.status_code),
content=response.content,
headers=response.headers,
parsed=_parse_response(client=client, response=response),
)


def sync_detailed(
*,
client: Union[AuthenticatedClient, Client],
) -> Response[GetModelsAllofResponse200]:
"""
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Response[GetModelsAllofResponse200]
"""

kwargs = _get_kwargs()

response = client.get_httpx_client().request(
**kwargs,
)

return _build_response(client=client, response=response)


def sync(
*,
client: Union[AuthenticatedClient, Client],
) -> Optional[GetModelsAllofResponse200]:
"""
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
GetModelsAllofResponse200
"""

return sync_detailed(
client=client,
).parsed


async def asyncio_detailed(
*,
client: Union[AuthenticatedClient, Client],
) -> Response[GetModelsAllofResponse200]:
"""
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Response[GetModelsAllofResponse200]
"""

kwargs = _get_kwargs()

response = await client.get_async_httpx_client().request(**kwargs)

return _build_response(client=client, response=response)


async def asyncio(
*,
client: Union[AuthenticatedClient, Client],
) -> Optional[GetModelsAllofResponse200]:
"""
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
GetModelsAllofResponse200
"""

return (
await asyncio_detailed(
client=client,
)
).parsed
Original file line number Diff line number Diff line change
@@ -34,9 +34,11 @@
from .body_upload_file_tests_upload_post_some_object import BodyUploadFileTestsUploadPostSomeObject
from .body_upload_file_tests_upload_post_some_optional_object import BodyUploadFileTestsUploadPostSomeOptionalObject
from .different_enum import DifferentEnum
from .extended import Extended
from .free_form_model import FreeFormModel
from .get_location_header_types_int_enum_header import GetLocationHeaderTypesIntEnumHeader
from .get_location_header_types_string_enum_header import GetLocationHeaderTypesStringEnumHeader
from .get_models_allof_response_200 import GetModelsAllofResponse200
from .http_validation_error import HTTPValidationError
from .import_ import Import
from .json_like_body import JsonLikeBody
@@ -111,9 +113,11 @@
"BodyUploadFileTestsUploadPostSomeObject",
"BodyUploadFileTestsUploadPostSomeOptionalObject",
"DifferentEnum",
"Extended",
"FreeFormModel",
"GetLocationHeaderTypesIntEnumHeader",
"GetLocationHeaderTypesStringEnumHeader",
"GetModelsAllofResponse200",
"HTTPValidationError",
"Import",
"JsonLikeBody",
514 changes: 514 additions & 0 deletions end_to_end_tests/golden-record/my_test_api_client/models/extended.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union

from attrs import define as _attrs_define
from attrs import field as _attrs_field

from ..types import UNSET, Unset

if TYPE_CHECKING:
from ..models.a_model import AModel
from ..models.extended import Extended


T = TypeVar("T", bound="GetModelsAllofResponse200")


@_attrs_define
class GetModelsAllofResponse200:
"""
Attributes:
aliased (Union[Unset, AModel]): A Model for testing all the ways custom objects can be used
extended (Union[Unset, Extended]):
model (Union[Unset, AModel]): A Model for testing all the ways custom objects can be used
"""

aliased: Union[Unset, "AModel"] = UNSET
extended: Union[Unset, "Extended"] = UNSET
model: Union[Unset, "AModel"] = UNSET
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:
aliased: Union[Unset, Dict[str, Any]] = UNSET
if not isinstance(self.aliased, Unset):
aliased = self.aliased.to_dict()

extended: Union[Unset, Dict[str, Any]] = UNSET
if not isinstance(self.extended, Unset):
extended = self.extended.to_dict()

model: Union[Unset, Dict[str, Any]] = UNSET
if not isinstance(self.model, Unset):
model = self.model.to_dict()

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})
if aliased is not UNSET:
field_dict["aliased"] = aliased
if extended is not UNSET:
field_dict["extended"] = extended
if model is not UNSET:
field_dict["model"] = model

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
from ..models.a_model import AModel
from ..models.extended import Extended

d = src_dict.copy()
_aliased = d.pop("aliased", UNSET)
aliased: Union[Unset, AModel]
if isinstance(_aliased, Unset):
aliased = UNSET
else:
aliased = AModel.from_dict(_aliased)

_extended = d.pop("extended", UNSET)
extended: Union[Unset, Extended]
if isinstance(_extended, Unset):
extended = UNSET
else:
extended = Extended.from_dict(_extended)

_model = d.pop("model", UNSET)
model: Union[Unset, AModel]
if isinstance(_model, Unset):
model = UNSET
else:
model = AModel.from_dict(_model)

get_models_allof_response_200 = cls(
aliased=aliased,
extended=extended,
model=model,
)

get_models_allof_response_200.additional_properties = d
return get_models_allof_response_200

@property
def additional_keys(self) -> List[str]:
return list(self.additional_properties.keys())

def __getitem__(self, key: str) -> Any:
return self.additional_properties[key]

def __setitem__(self, key: str, value: Any) -> None:
self.additional_properties[key] = value

def __delitem__(self, key: str) -> None:
del self.additional_properties[key]

def __contains__(self, key: str) -> bool:
return key in self.additional_properties
1 change: 1 addition & 0 deletions openapi_python_client/parser/bodies.py
Original file line number Diff line number Diff line change
@@ -117,6 +117,7 @@ def body_from_data(
**schemas.classes_by_name,
prop.class_info.name: prop,
},
models_to_process=[*schemas.models_to_process, prop],
)
bodies.append(
Body(
21 changes: 15 additions & 6 deletions openapi_python_client/parser/properties/__init__.py
Original file line number Diff line number Diff line change
@@ -126,7 +126,7 @@ def _property_from_ref(
return prop, schemas


def property_from_data( # noqa: PLR0911
def property_from_data( # noqa: PLR0911, PLR0912
name: str,
required: bool,
data: oai.Reference | oai.Schema,
@@ -153,7 +153,7 @@ def property_from_data( # noqa: PLR0911
sub_data: list[oai.Schema | oai.Reference] = data.allOf + data.anyOf + data.oneOf
# A union of a single reference should just be passed through to that reference (don't create copy class)
if len(sub_data) == 1 and isinstance(sub_data[0], oai.Reference):
return _property_from_ref(
prop, schemas = _property_from_ref(
name=name,
required=required,
parent=data,
@@ -162,6 +162,16 @@ def property_from_data( # noqa: PLR0911
config=config,
roots=roots,
)
# We won't be generating a separate Python class for this schema - references to it will just use
# the class for the schema it's referencing - so we don't add it to classes_by_name; but we do
# add it to models_to_process, if it's a model, because its properties still need to be resolved.
if isinstance(prop, ModelProperty):
schemas = evolve(
schemas,
models_to_process=[*schemas.models_to_process, prop],
)
return prop, schemas

if data.type == oai.DataType.BOOLEAN:
return (
BooleanProperty.build(
@@ -341,7 +351,7 @@ def _process_model_errors(


def _process_models(*, schemas: Schemas, config: Config) -> Schemas:
to_process = (prop for prop in schemas.classes_by_name.values() if isinstance(prop, ModelProperty))
to_process = schemas.models_to_process
still_making_progress = True
final_model_errors: list[tuple[ModelProperty, PropertyError]] = []
latest_model_errors: list[tuple[ModelProperty, PropertyError]] = []
@@ -368,12 +378,11 @@ def _process_models(*, schemas: Schemas, config: Config) -> Schemas:
continue
schemas = schemas_or_err
still_making_progress = True
to_process = (prop for prop in next_round)
to_process = next_round

final_model_errors.extend(latest_model_errors)
errors = _process_model_errors(final_model_errors, schemas=schemas)
schemas.errors.extend(errors)
return schemas
return evolve(schemas, errors=[*schemas.errors, *errors], models_to_process=to_process)


def build_schemas(
6 changes: 5 additions & 1 deletion openapi_python_client/parser/properties/model_property.py
Original file line number Diff line number Diff line change
@@ -119,7 +119,11 @@ def build(
)
return error, schemas

schemas = evolve(schemas, classes_by_name={**schemas.classes_by_name, class_info.name: prop})
schemas = evolve(
schemas,
classes_by_name={**schemas.classes_by_name, class_info.name: prop},
models_to_process=[*schemas.models_to_process, prop],
)
return prop, schemas

@classmethod
3 changes: 3 additions & 0 deletions openapi_python_client/parser/properties/schemas.py
Original file line number Diff line number Diff line change
@@ -22,8 +22,10 @@
from ..errors import ParameterError, ParseError, PropertyError

if TYPE_CHECKING: # pragma: no cover
from .model_property import ModelProperty
from .property import Property
else:
ModelProperty = "ModelProperty"
Property = "Property"


@@ -77,6 +79,7 @@ class Schemas:
classes_by_reference: Dict[ReferencePath, Property] = field(factory=dict)
dependencies: Dict[ReferencePath, Set[Union[ReferencePath, ClassName]]] = field(factory=dict)
classes_by_name: Dict[ClassName, Property] = field(factory=dict)
models_to_process: List[ModelProperty] = field(factory=list)
errors: List[ParseError] = field(factory=list)

def add_dependencies(self, ref_path: ReferencePath, roots: Set[Union[ReferencePath, ClassName]]) -> None:
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -129,7 +129,7 @@ composite = ["test --cov openapi_python_client tests --cov-report=term-missing"]

[tool.pdm.scripts.regen_integration]
shell = """
openapi-python-client update --url https://raw.githubusercontent.com/openapi-generators/openapi-test-server/main/openapi.json --config integration-tests/config.yaml --meta pdm \
openapi-python-client generate --overwrite --url https://raw.githubusercontent.com/openapi-generators/openapi-test-server/main/openapi.json --config integration-tests/config.yaml --meta none --output-path integration-tests/integration_tests \
"""

[build-system]
92 changes: 58 additions & 34 deletions tests/test_parser/test_properties/test_init.py
Original file line number Diff line number Diff line change
@@ -530,6 +530,7 @@ def test_property_from_data_ref_enum_with_overridden_default(self, enum_property
prop, new_schemas = property_from_data(
name=name, required=required, data=data, schemas=schemas, parent_name="", config=config
)
new_schemas = attr.evolve(new_schemas, models_to_process=[]) # intermediate state irrelevant to this test

assert prop == enum_property_factory(
name="some_enum",
@@ -911,37 +912,6 @@ def test_retries_failing_properties_while_making_progress(self, mocker, config):


class TestProcessModels:
def test_retries_failing_models_while_making_progress(
self, mocker, model_property_factory, any_property_factory, config
):
from openapi_python_client.parser.properties import _process_models

first_model = model_property_factory()
second_class_name = ClassName("second", "")
schemas = Schemas(
classes_by_name={
ClassName("first", ""): first_model,
second_class_name: model_property_factory(),
ClassName("non-model", ""): any_property_factory(),
}
)
process_model = mocker.patch(
f"{MODULE_NAME}.process_model", side_effect=[PropertyError(), Schemas(), PropertyError()]
)
process_model_errors = mocker.patch(f"{MODULE_NAME}._process_model_errors", return_value=["error"])

result = _process_models(schemas=schemas, config=config)

process_model.assert_has_calls(
[
call(first_model, schemas=schemas, config=config),
call(schemas.classes_by_name[second_class_name], schemas=schemas, config=config),
call(first_model, schemas=result, config=config),
]
)
assert process_model_errors.was_called_once_with([(first_model, PropertyError())])
assert all(error in result.errors for error in process_model_errors.return_value)

def test_detect_recursive_allof_reference_no_retry(self, mocker, model_property_factory, config):
from openapi_python_client.parser.properties import Class, _process_models
from openapi_python_client.schema import Reference
@@ -950,14 +920,16 @@ def test_detect_recursive_allof_reference_no_retry(self, mocker, model_property_
recursive_model = model_property_factory(
class_info=Class(name=class_name, module_name=PythonIdentifier("module_name", ""))
)
second_model = model_property_factory()
schemas = Schemas(
classes_by_name={
"recursive": recursive_model,
"second": model_property_factory(),
}
"second": second_model,
},
models_to_process=[recursive_model, second_model],
)
recursion_error = PropertyError(data=Reference.model_construct(ref=f"#/{class_name}"))
process_model = mocker.patch(f"{MODULE_NAME}.process_model", side_effect=[recursion_error, Schemas()])
process_model = mocker.patch(f"{MODULE_NAME}.process_model", side_effect=[recursion_error, schemas])
process_model_errors = mocker.patch(f"{MODULE_NAME}._process_model_errors", return_value=["error"])

result = _process_models(schemas=schemas, config=config)
@@ -972,6 +944,58 @@ def test_detect_recursive_allof_reference_no_retry(self, mocker, model_property_
assert all(error in result.errors for error in process_model_errors.return_value)
assert "\n\nRecursive allOf reference found" in recursion_error.detail

def test_resolve_reference_to_single_allof_reference(self, config, model_property_factory):
# test for https://github.com/openapi-generators/openapi-python-client/issues/1091
from openapi_python_client.parser.properties import Schemas, build_schemas

components = {
"Model1": oai.Schema.model_construct(
type="object",
properties={
"prop1": oai.Schema.model_construct(type="string"),
},
),
"Model2": oai.Schema.model_construct(
allOf=[
oai.Reference.model_construct(ref="#/components/schemas/Model1"),
]
),
"Model3": oai.Schema.model_construct(
allOf=[
oai.Reference.model_construct(ref="#/components/schemas/Model2"),
oai.Schema.model_construct(
type="object",
properties={
"prop2": oai.Schema.model_construct(type="string"),
},
),
],
),
}
schemas = Schemas()

result = build_schemas(components=components, schemas=schemas, config=config)

assert result.errors == []
assert result.models_to_process == []

# Classes should only be generated for Model1 and Model3
assert result.classes_by_name.keys() == {"Model1", "Model3"}

# References to Model2 should be resolved to the same class as Model1
assert result.classes_by_reference.keys() == {
"/components/schemas/Model1",
"/components/schemas/Model2",
"/components/schemas/Model3",
}
assert (
result.classes_by_reference["/components/schemas/Model2"].class_info
== result.classes_by_reference["/components/schemas/Model1"].class_info
)

# Verify that Model3 extended the properties from Model1
assert [p.name for p in result.classes_by_name["Model3"].optional_properties] == ["prop1", "prop2"]


class TestPropogateRemoval:
def test_propogate_removal_class_name(self):