diff --git a/end_to_end_tests/custom-templates-golden-record/README.md b/end_to_end_tests/custom-templates-golden-record/README.md new file mode 100644 index 000000000..e5106eea7 --- /dev/null +++ b/end_to_end_tests/custom-templates-golden-record/README.md @@ -0,0 +1 @@ +my-test-api-client \ No newline at end of file diff --git a/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/__init__.py b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/__init__.py new file mode 100644 index 000000000..3ee5dbaf0 --- /dev/null +++ b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/__init__.py @@ -0,0 +1,21 @@ +""" Contains methods for accessing the API """ + +from typing import Type + +from my_test_api_client.api.default import DefaultEndpoints +from my_test_api_client.api.parameters import ParametersEndpoints +from my_test_api_client.api.tests import TestsEndpoints + + +class MyTestApiClientApi: + @classmethod + def tests(cls) -> Type[TestsEndpoints]: + return TestsEndpoints + + @classmethod + def default(cls) -> Type[DefaultEndpoints]: + return DefaultEndpoints + + @classmethod + def parameters(cls) -> Type[ParametersEndpoints]: + return ParametersEndpoints diff --git a/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/default/__init__.py b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/default/__init__.py new file mode 100644 index 000000000..4d0eb4fb5 --- /dev/null +++ b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/default/__init__.py @@ -0,0 +1,15 @@ +""" Contains methods for accessing the API Endpoints """ + +import types + +from my_test_api_client.api.default import get_common_parameters, post_common_parameters + + +class DefaultEndpoints: + @classmethod + def get_common_parameters(cls) -> types.ModuleType: + return get_common_parameters + + @classmethod + def post_common_parameters(cls) -> types.ModuleType: + return post_common_parameters diff --git a/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/parameters/__init__.py b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/parameters/__init__.py new file mode 100644 index 000000000..b92c6d96b --- /dev/null +++ b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/parameters/__init__.py @@ -0,0 +1,11 @@ +""" Contains methods for accessing the API Endpoints """ + +import types + +from my_test_api_client.api.parameters import get_same_name_multiple_locations_param + + +class ParametersEndpoints: + @classmethod + def get_same_name_multiple_locations_param(cls) -> types.ModuleType: + return get_same_name_multiple_locations_param diff --git a/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/tests/__init__.py b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/tests/__init__.py new file mode 100644 index 000000000..dcb864fe9 --- /dev/null +++ b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/tests/__init__.py @@ -0,0 +1,136 @@ +""" Contains methods for accessing the API Endpoints """ + +import types + +from my_test_api_client.api.tests import ( + defaults_tests_defaults_post, + get_basic_list_of_booleans, + get_basic_list_of_floats, + get_basic_list_of_integers, + get_basic_list_of_strings, + get_user_list, + int_enum_tests_int_enum_post, + json_body_tests_json_body_post, + no_response_tests_no_response_get, + octet_stream_tests_octet_stream_get, + optional_value_tests_optional_query_param, + post_form_data, + test_inline_objects, + token_with_cookie_auth_token_with_cookie_get, + unsupported_content_tests_unsupported_content_get, + upload_file_tests_upload_post, +) + + +class TestsEndpoints: + @classmethod + def get_user_list(cls) -> types.ModuleType: + """ + Get a list of things + """ + return get_user_list + + @classmethod + def get_basic_list_of_strings(cls) -> types.ModuleType: + """ + Get a list of strings + """ + return get_basic_list_of_strings + + @classmethod + def get_basic_list_of_integers(cls) -> types.ModuleType: + """ + Get a list of integers + """ + return get_basic_list_of_integers + + @classmethod + def get_basic_list_of_floats(cls) -> types.ModuleType: + """ + Get a list of floats + """ + return get_basic_list_of_floats + + @classmethod + def get_basic_list_of_booleans(cls) -> types.ModuleType: + """ + Get a list of booleans + """ + return get_basic_list_of_booleans + + @classmethod + def post_form_data(cls) -> types.ModuleType: + """ + Post form data + """ + return post_form_data + + @classmethod + def upload_file_tests_upload_post(cls) -> types.ModuleType: + """ + Upload a file + """ + return upload_file_tests_upload_post + + @classmethod + def json_body_tests_json_body_post(cls) -> types.ModuleType: + """ + Try sending a JSON body + """ + return json_body_tests_json_body_post + + @classmethod + def defaults_tests_defaults_post(cls) -> types.ModuleType: + """ + Defaults + """ + return defaults_tests_defaults_post + + @classmethod + def octet_stream_tests_octet_stream_get(cls) -> types.ModuleType: + """ + Octet Stream + """ + return octet_stream_tests_octet_stream_get + + @classmethod + def no_response_tests_no_response_get(cls) -> types.ModuleType: + """ + No Response + """ + return no_response_tests_no_response_get + + @classmethod + def unsupported_content_tests_unsupported_content_get(cls) -> types.ModuleType: + """ + Unsupported Content + """ + return unsupported_content_tests_unsupported_content_get + + @classmethod + def int_enum_tests_int_enum_post(cls) -> types.ModuleType: + """ + Int Enum + """ + return int_enum_tests_int_enum_post + + @classmethod + def test_inline_objects(cls) -> types.ModuleType: + """ + Test Inline Objects + """ + return test_inline_objects + + @classmethod + def optional_value_tests_optional_query_param(cls) -> types.ModuleType: + """ + Test optional query parameters + """ + return optional_value_tests_optional_query_param + + @classmethod + def token_with_cookie_auth_token_with_cookie_get(cls) -> types.ModuleType: + """ + Test optional cookie parameters + """ + return token_with_cookie_auth_token_with_cookie_get diff --git a/end_to_end_tests/regen_golden_record.py b/end_to_end_tests/regen_golden_record.py index aaa6aa850..1d4dc943d 100644 --- a/end_to_end_tests/regen_golden_record.py +++ b/end_to_end_tests/regen_golden_record.py @@ -1,12 +1,16 @@ """ Regenerate golden-record """ +import filecmp +import os import shutil +import tempfile from pathlib import Path from typer.testing import CliRunner from openapi_python_client.cli import app -if __name__ == "__main__": + +def regen_golden_record(): runner = CliRunner() openapi_path = Path(__file__).parent / "openapi.json" @@ -24,3 +28,52 @@ if result.exception: raise result.exception output_path.rename(gr_path) + + +def regen_custom_template_golden_record(): + runner = CliRunner() + openapi_path = Path(__file__).parent / "openapi.json" + tpl_dir = Path(__file__).parent / "test_custom_templates" + + gr_path = Path(__file__).parent / "golden-record" + tpl_gr_path = Path(__file__).parent / "custom-templates-golden-record" + + output_path = Path(tempfile.mkdtemp()) + config_path = Path(__file__).parent / "config.yml" + + shutil.rmtree(tpl_gr_path, ignore_errors=True) + + os.chdir(str(output_path.absolute())) + result = runner.invoke( + app, ["generate", f"--config={config_path}", f"--path={openapi_path}", f"--custom-template-path={tpl_dir}"] + ) + + if result.stdout: + generated_output_path = output_path / "my-test-api-client" + for f in generated_output_path.glob("**/*"): # nb: works for Windows and Unix + relative_to_generated = f.relative_to(generated_output_path) + gr_file = gr_path / relative_to_generated + if not gr_file.exists(): + print(f"{gr_file} does not exist, ignoring") + continue + + if not gr_file.is_file(): + continue + + if not filecmp.cmp(gr_file, f, shallow=False): + target_file = tpl_gr_path / relative_to_generated + target_dir = target_file.parent + + target_dir.mkdir(parents=True, exist_ok=True) + shutil.copy(f"{f}", f"{target_file}") + + shutil.rmtree(output_path, ignore_errors=True) + + if result.exception: + shutil.rmtree(output_path, ignore_errors=True) + raise result.exception + + +if __name__ == "__main__": + regen_golden_record() + regen_custom_template_golden_record() diff --git a/end_to_end_tests/test_custom_templates/api_init.py.jinja b/end_to_end_tests/test_custom_templates/api_init.py.jinja new file mode 100644 index 000000000..03c2a2f6f --- /dev/null +++ b/end_to_end_tests/test_custom_templates/api_init.py.jinja @@ -0,0 +1,13 @@ +""" Contains methods for accessing the API """ + +from typing import Type +{% for tag in endpoint_collections_by_tag.keys() %} +from {{ package_name }}.api.{{ tag }} import {{ utils.pascal_case(tag) }}Endpoints +{% endfor %} + +class {{ utils.pascal_case(package_name) }}Api: +{% for tag in endpoint_collections_by_tag.keys() %} + @classmethod + def {{ tag }}(cls) -> Type[{{ utils.pascal_case(tag) }}Endpoints]: + return {{ utils.pascal_case(tag) }}Endpoints +{% endfor %} diff --git a/end_to_end_tests/test_custom_templates/endpoint_init.py.jinja b/end_to_end_tests/test_custom_templates/endpoint_init.py.jinja new file mode 100644 index 000000000..57e8ba124 --- /dev/null +++ b/end_to_end_tests/test_custom_templates/endpoint_init.py.jinja @@ -0,0 +1,24 @@ +""" Contains methods for accessing the API Endpoints """ + +import types +{% for endpoint in endpoint_collection.endpoints %} +from {{ package_name }}.api.{{ endpoint_collection.tag }} import {{ utils.snake_case(endpoint.name) }} +{% endfor %} + +class {{ utils.pascal_case(endpoint_collection.tag) }}Endpoints: + +{% for endpoint in endpoint_collection.endpoints %} + + @classmethod + def {{ utils.snake_case(endpoint.name) }}(cls) -> types.ModuleType: + {% if endpoint.description %} + """ + {{ endpoint.description }} + """ + {% elif endpoint.summary %} + """ + {{ endpoint.summary }} + """ + {% endif %} + return {{ utils.snake_case(endpoint.name) }} +{% endfor %} diff --git a/end_to_end_tests/test_end_to_end.py b/end_to_end_tests/test_end_to_end.py index fa4d21598..bcc8b12e1 100644 --- a/end_to_end_tests/test_end_to_end.py +++ b/end_to_end_tests/test_end_to_end.py @@ -1,7 +1,7 @@ import shutil from filecmp import cmpfiles, dircmp from pathlib import Path -from typing import Dict, Optional +from typing import Dict, List, Optional import pytest from typer.testing import CliRunner @@ -12,8 +12,18 @@ def _compare_directories( record: Path, test_subject: Path, - expected_differences: Optional[Dict[str, str]] = None, + expected_differences: Dict[Path, str], + depth=0, ): + """ + Compare two directories and assert that only expected_differences are different + + Args: + record: Path to the expected output + test_subject: Path to the generated code being checked + expected_differences: key: path relative to generated directory, value: expected generated content + depth: Used to track recursion + """ first_printable = record.relative_to(Path.cwd()) second_printable = test_subject.relative_to(Path.cwd()) dc = dircmp(record, test_subject) @@ -22,30 +32,42 @@ def _compare_directories( pytest.fail(f"{first_printable} or {second_printable} was missing: {missing_files}", pytrace=False) expected_differences = expected_differences or {} - _, mismatch, errors = cmpfiles(record, test_subject, dc.common_files, shallow=False) - mismatch = set(mismatch) - - for file_name in mismatch | set(expected_differences.keys()): - if file_name not in expected_differences: + _, mismatches, errors = cmpfiles(record, test_subject, dc.common_files, shallow=False) + mismatches = set(mismatches) + + expected_path_mismatches = [] + for file_name in mismatches: + mismatch_file_path = test_subject.joinpath(file_name) + expected_content = expected_differences.get(mismatch_file_path) + if expected_content is None: continue - if file_name not in mismatch: - pytest.fail(f"Expected {file_name} to be different but it was not", pytrace=False) - generated = (test_subject / file_name).read_text() - assert generated == expected_differences[file_name], f"Unexpected output in {file_name}" - del expected_differences[file_name] - mismatch.remove(file_name) - - if mismatch: + + generated_content = (test_subject / file_name).read_text() + assert generated_content == expected_content, f"Unexpected output in {mismatch_file_path}" + expected_path_mismatches.append(mismatch_file_path) + + for path_mismatch in expected_path_mismatches: + matched_file_name = path_mismatch.name + mismatches.remove(matched_file_name) + del expected_differences[path_mismatch] + + if mismatches: pytest.fail( - f"{first_printable} and {second_printable} had differing files: {mismatch}, and errors {errors}", + f"{first_printable} and {second_printable} had differing files: {mismatches}, and errors {errors}", pytrace=False, ) for sub_path in dc.common_dirs: - _compare_directories(record / sub_path, test_subject / sub_path, expected_differences=expected_differences) + _compare_directories( + record / sub_path, test_subject / sub_path, expected_differences=expected_differences, depth=depth + 1 + ) + + if depth == 0 and len(expected_differences.keys()) > 0: + failure = "\n".join([f"Expected {path} to be different but it was not" for path in expected_differences.keys()]) + pytest.fail(failure, pytrace=False) -def run_e2e_test(extra_args=None, expected_differences=None): +def run_e2e_test(extra_args: List[str], expected_differences: Dict[Path, str]): runner = CliRunner() openapi_path = Path(__file__).parent / "openapi.json" config_path = Path(__file__).parent / "config.yml" @@ -60,6 +82,9 @@ def run_e2e_test(extra_args=None, expected_differences=None): if result.exit_code != 0: raise result.exception + + # Use absolute paths for expected differences for easier comparisons + expected_differences = {output_path.joinpath(key): value for key, value in expected_differences.items()} _compare_directories(gr_path, output_path, expected_differences=expected_differences) import mypy.api @@ -71,11 +96,24 @@ def run_e2e_test(extra_args=None, expected_differences=None): def test_end_to_end(): - run_e2e_test() + run_e2e_test([], {}) def test_custom_templates(): + expected_differences = {} # key: path relative to generated directory, value: expected generated content + expected_difference_paths = [ + Path("README.md"), + Path("my_test_api_client").joinpath("api", "__init__.py"), + Path("my_test_api_client").joinpath("api", "tests", "__init__.py"), + Path("my_test_api_client").joinpath("api", "default", "__init__.py"), + Path("my_test_api_client").joinpath("api", "parameters", "__init__.py"), + ] + + golden_tpls_root_dir = Path(__file__).parent.joinpath("custom-templates-golden-record") + for expected_difference_path in expected_difference_paths: + expected_differences[expected_difference_path] = (golden_tpls_root_dir / expected_difference_path).read_text() + run_e2e_test( - extra_args=["--custom-template-path=end_to_end_tests/test_custom_templates"], - expected_differences={"README.md": "my-test-api-client"}, + extra_args=["--custom-template-path=end_to_end_tests/test_custom_templates/"], + expected_differences=expected_differences, ) diff --git a/openapi_python_client/__init__.py b/openapi_python_client/__init__.py index 2a7cf574b..b1458e1a4 100644 --- a/openapi_python_client/__init__.py +++ b/openapi_python_client/__init__.py @@ -81,6 +81,15 @@ def __init__( self.version: str = config.package_version_override or openapi.version self.env.filters.update(TEMPLATE_FILTERS) + self.env.globals.update( + utils=utils, + package_name=self.package_name, + package_dir=self.package_dir, + package_description=self.package_description, + package_version=self.version, + project_name=self.project_name, + project_dir=self.project_dir, + ) def build(self) -> Sequence[GeneratorError]: """Create the project from templates""" @@ -143,9 +152,7 @@ def _create_package(self) -> None: package_init = self.package_dir / "__init__.py" package_init_template = self.env.get_template("package_init.py.jinja") - package_init.write_text( - package_init_template.render(description=self.package_description), encoding=self.file_encoding - ) + package_init.write_text(package_init_template.render(), encoding=self.file_encoding) if self.meta != MetaType.NONE: pytyped = self.package_dir / "py.typed" @@ -167,9 +174,7 @@ def _build_metadata(self) -> None: readme = self.project_dir / "README.md" readme_template = self.env.get_template("README.md.jinja") readme.write_text( - readme_template.render( - project_name=self.project_name, description=self.package_description, package_name=self.package_name - ), + readme_template.render(), encoding=self.file_encoding, ) @@ -183,12 +188,7 @@ def _build_pyproject_toml(self, *, use_poetry: bool) -> None: pyproject_template = self.env.get_template(template) pyproject_path = self.project_dir / "pyproject.toml" pyproject_path.write_text( - pyproject_template.render( - project_name=self.project_name, - package_name=self.package_name, - version=self.version, - description=self.package_description, - ), + pyproject_template.render(), encoding=self.file_encoding, ) @@ -196,12 +196,7 @@ def _build_setup_py(self) -> None: template = self.env.get_template("setup.py.jinja") path = self.project_dir / "setup.py" path.write_text( - template.render( - project_name=self.project_name, - package_name=self.package_name, - version=self.version, - description=self.package_description, - ), + template.render(), encoding=self.file_encoding, ) @@ -239,16 +234,29 @@ def _build_api(self) -> None: client_path.write_text(client_template.render(), encoding=self.file_encoding) # Generate endpoints + endpoint_collections_by_tag = self.openapi.endpoint_collections_by_tag api_dir = self.package_dir / "api" api_dir.mkdir() - api_init = api_dir / "__init__.py" - api_init.write_text('""" Contains methods for accessing the API """', encoding=self.file_encoding) + api_init_path = api_dir / "__init__.py" + api_init_template = self.env.get_template("api_init.py.jinja") + api_init_path.write_text( + api_init_template.render( + endpoint_collections_by_tag=endpoint_collections_by_tag, + ), + encoding=self.file_encoding, + ) endpoint_template = self.env.get_template("endpoint_module.py.jinja") - for tag, collection in self.openapi.endpoint_collections_by_tag.items(): + for tag, collection in endpoint_collections_by_tag.items(): tag_dir = api_dir / tag tag_dir.mkdir() - (tag_dir / "__init__.py").touch() + + endpoint_init_path = tag_dir / "__init__.py" + endpoint_init_template = self.env.get_template("endpoint_init.py.jinja") + endpoint_init_path.write_text( + endpoint_init_template.render(endpoint_collection=collection), + encoding=self.file_encoding, + ) for endpoint in collection.endpoints: module_path = tag_dir / f"{snake_case(endpoint.name)}.py" diff --git a/openapi_python_client/templates/README.md.jinja b/openapi_python_client/templates/README.md.jinja index 2a5d18d87..e6de0dda5 100644 --- a/openapi_python_client/templates/README.md.jinja +++ b/openapi_python_client/templates/README.md.jinja @@ -1,5 +1,5 @@ # {{ project_name }} -{{ description }} +{{ package_description }} ## Usage First, create a client: diff --git a/openapi_python_client/templates/api_init.py.jinja b/openapi_python_client/templates/api_init.py.jinja new file mode 100644 index 000000000..dc035f4ce --- /dev/null +++ b/openapi_python_client/templates/api_init.py.jinja @@ -0,0 +1 @@ +""" Contains methods for accessing the API """ diff --git a/openapi_python_client/templates/endpoint_init.py.jinja b/openapi_python_client/templates/endpoint_init.py.jinja new file mode 100644 index 000000000..e69de29bb diff --git a/openapi_python_client/templates/package_init.py.jinja b/openapi_python_client/templates/package_init.py.jinja index 917cd7dde..f146549d0 100644 --- a/openapi_python_client/templates/package_init.py.jinja +++ b/openapi_python_client/templates/package_init.py.jinja @@ -1,2 +1,2 @@ -""" {{ description }} """ +""" {{ package_description }} """ from .client import AuthenticatedClient, Client diff --git a/openapi_python_client/templates/pyproject.toml.jinja b/openapi_python_client/templates/pyproject.toml.jinja index 9e311a1a8..695092f48 100644 --- a/openapi_python_client/templates/pyproject.toml.jinja +++ b/openapi_python_client/templates/pyproject.toml.jinja @@ -1,7 +1,7 @@ [tool.poetry] name = "{{ project_name }}" -version = "{{ version }}" -description = "{{ description }}" +version = "{{ package_version }}" +description = "{{ package_description }}" authors = [] diff --git a/openapi_python_client/templates/setup.py.jinja b/openapi_python_client/templates/setup.py.jinja index 0dd31d23b..027120ab9 100644 --- a/openapi_python_client/templates/setup.py.jinja +++ b/openapi_python_client/templates/setup.py.jinja @@ -7,8 +7,8 @@ long_description = (here / "README.md").read_text(encoding="utf-8") setup( name="{{ project_name }}", - version="{{ version }}", - description="{{ description }}", + version="{{ package_version }}", + description="{{ package_description }}", long_description=long_description, long_description_content_type="text/markdown", package_dir={"": "{{ package_name }}"}, diff --git a/tests/test___init__.py b/tests/test___init__.py index 0579e83f0..3e1efbd5c 100644 --- a/tests/test___init__.py +++ b/tests/test___init__.py @@ -403,11 +403,7 @@ def test__build_metadata_poetry(self, mocker): project._build_metadata() project.env.get_template.assert_has_calls([mocker.call("README.md.jinja"), mocker.call(".gitignore.jinja")]) - readme_template.render.assert_called_once_with( - description=project.package_description, - project_name=project.project_name, - package_name=project.package_name, - ) + readme_template.render.assert_called_once_with() readme_path.write_text.assert_called_once_with(readme_template.render(), encoding="utf-8") git_ignore_template.render.assert_called_once() git_ignore_path.write_text.assert_called_once_with(git_ignore_template.render(), encoding="utf-8") @@ -440,11 +436,7 @@ def test__build_metadata_setup(self, mocker): project._build_metadata() project.env.get_template.assert_has_calls([mocker.call("README.md.jinja"), mocker.call(".gitignore.jinja")]) - readme_template.render.assert_called_once_with( - description=project.package_description, - project_name=project.project_name, - package_name=project.package_name, - ) + readme_template.render.assert_called_once_with() readme_path.write_text.assert_called_once_with(readme_template.render(), encoding="utf-8") git_ignore_template.render.assert_called_once() git_ignore_path.write_text.assert_called_once_with(git_ignore_template.render(), encoding="utf-8") @@ -483,12 +475,7 @@ def test__build_pyproject_toml(self, mocker, use_poetry): project.env.get_template.assert_called_once_with(template_path) - pyproject_template.render.assert_called_once_with( - project_name=project.project_name, - package_name=project.package_name, - version=project.version, - description=project.package_description, - ) + pyproject_template.render.assert_called_once_with() pyproject_path.write_text.assert_called_once_with(pyproject_template.render(), encoding="utf-8") def test__build_setup_py(self, mocker): @@ -511,12 +498,7 @@ def test__build_setup_py(self, mocker): project.env.get_template.assert_called_once_with("setup.py.jinja") - setup_template.render.assert_called_once_with( - project_name=project.project_name, - package_name=project.package_name, - version=project.version, - description=project.package_description, - ) + setup_template.render.assert_called_once_with() setup_path.write_text.assert_called_once_with(setup_template.render(), encoding="utf-8")