diff --git a/pyproject.toml b/pyproject.toml index 7d325e5..11bec85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ classifiers = [ ] dependencies = [ "typer >= 0.12.3", + "tomli >= 2.0.1; python_version < '3.11'", ] [project.optional-dependencies] diff --git a/src/fastapi_cli/cli.py b/src/fastapi_cli/cli.py index d5bcb8e..e607feb 100644 --- a/src/fastapi_cli/cli.py +++ b/src/fastapi_cli/cli.py @@ -12,6 +12,7 @@ from fastapi_cli.exceptions import FastAPICLIException from . import __version__ +from .config import CommandWithProjectConfig from .logging import setup_logging app = typer.Typer(rich_markup_mode="rich") @@ -100,12 +101,13 @@ def _run( ) -@app.command() +@app.command(cls=CommandWithProjectConfig) def dev( path: Annotated[ Union[Path, None], typer.Argument( - help="A path to a Python file or package directory (with [blue]__init__.py[/blue] files) containing a [bold]FastAPI[/bold] app. If not provided, a default set of paths will be tried." + help="A path to a Python file or package directory (with [blue]__init__.py[/blue] files) containing a [bold]FastAPI[/bold] app. If not provided, a default set of paths will be tried.", + envvar=["FASTAPI_DEV_PATH", "FASTAPI_PATH"], ), ] = None, *, @@ -183,12 +185,13 @@ def dev( ) -@app.command() +@app.command(cls=CommandWithProjectConfig) def run( path: Annotated[ Union[Path, None], typer.Argument( - help="A path to a Python file or package directory (with [blue]__init__.py[/blue] files) containing a [bold]FastAPI[/bold] app. If not provided, a default set of paths will be tried." + help="A path to a Python file or package directory (with [blue]__init__.py[/blue] files) containing a [bold]FastAPI[/bold] app. If not provided, a default set of paths will be tried.", + envvar=["FASTAPI_RUN_PATH", "FASTAPI_PATH"], ), ] = None, *, @@ -274,4 +277,4 @@ def run( def main() -> None: - app() + app(auto_envvar_prefix="FASTAPI") diff --git a/src/fastapi_cli/config.py b/src/fastapi_cli/config.py new file mode 100644 index 0000000..5a2ec8f --- /dev/null +++ b/src/fastapi_cli/config.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import logging +import sys +from pathlib import Path +from typing import Any, Sequence + +from click import BadParameter, Context +from typer.core import TyperCommand + +if sys.version_info >= (3, 11): + import tomllib +else: + import tomli as tomllib + +logger = logging.getLogger(__name__) + + +def get_toml_key(config: dict[str, Any], keys: Sequence[str]) -> dict[str, Any]: + for key in keys: + config = config.get(key, {}) + return config + + +def read_pyproject_file(keys: Sequence[str]) -> dict[str, Any] | None: + path = Path("pyproject.toml") + if not path.exists(): + return None + + with path.open("rb") as f: + data = tomllib.load(f) + + config = get_toml_key(data, keys) + return config or None + + +class CommandWithProjectConfig(TyperCommand): + """Command class which loads parameters from a pyproject.toml file. + + The table `tool.fastapi.cli` will be used. An additional subtable for the + running command will also be used. e.g. `tool.fastapi.cli.dev`. Options + on subcommand tables will override options from the cli table. + + Example: + + ```toml + [tool.fastapi.cli] + path = "asgi.py" + app = "application" + + [tool.fastapi.cli.dev] + host = "0.0.0.0" + port = 5000 + + [tool.fastapi.cli.run] + reload = true + ``` + """ + + toml_keys = ("tool", "fastapi", "cli") + + def load_config_table( + self, + ctx: Context, + config: dict[str, Any], + config_path: str | None = None, + ) -> None: + if config_path is not None: + config = config.get(config_path, {}) + if not config: + return + for param in ctx.command.params: + param_name = param.name or "" + if param_name in config: + try: + value = param.type_cast_value(ctx, config[param_name]) + except (TypeError, BadParameter) as e: + keys: list[str] = list(self.toml_keys) + if config_path is not None: + keys.append(config_path) + keys.append(param_name) + full_path = ".".join(keys) + ctx.fail(f"Error parsing pyproject.toml: key '{full_path}': {e}") + else: + ctx.params[param_name] = value + + def invoke(self, ctx: Context) -> Any: + config = read_pyproject_file(self.toml_keys) + if config is not None: + logger.info("Loading configuration from pyproject.toml") + command_name = ctx.command.name or "" + self.load_config_table(ctx, config) + self.load_config_table(ctx, config, command_name) + + return super().invoke(ctx) diff --git a/tests/assets/projects/bad_configured_app/app.py b/tests/assets/projects/bad_configured_app/app.py new file mode 100644 index 0000000..6cdf1de --- /dev/null +++ b/tests/assets/projects/bad_configured_app/app.py @@ -0,0 +1,8 @@ +from fastapi import FastAPI + +app = FastAPI() + + +@app.get("/") +def app_root(): + return {"message": "badly configured app"} diff --git a/tests/assets/projects/bad_configured_app/pyproject.toml b/tests/assets/projects/bad_configured_app/pyproject.toml new file mode 100644 index 0000000..75cbbbf --- /dev/null +++ b/tests/assets/projects/bad_configured_app/pyproject.toml @@ -0,0 +1,2 @@ +[tool.fastapi.cli.run] +port = "http" diff --git a/tests/assets/projects/configured_app/pyproject.toml b/tests/assets/projects/configured_app/pyproject.toml new file mode 100644 index 0000000..5f1f4f0 --- /dev/null +++ b/tests/assets/projects/configured_app/pyproject.toml @@ -0,0 +1,3 @@ +[tool.fastapi.cli] +path = "server.py" +app = "application" diff --git a/tests/assets/projects/configured_app/server.py b/tests/assets/projects/configured_app/server.py new file mode 100644 index 0000000..2364e73 --- /dev/null +++ b/tests/assets/projects/configured_app/server.py @@ -0,0 +1,8 @@ +from fastapi import FastAPI + +application = FastAPI() + + +@application.get("/") +def app_root(): + return {"message": "configured app"} diff --git a/tests/assets/projects/configured_app_subtable/app.py b/tests/assets/projects/configured_app_subtable/app.py new file mode 100644 index 0000000..b5b10be --- /dev/null +++ b/tests/assets/projects/configured_app_subtable/app.py @@ -0,0 +1,8 @@ +from fastapi import FastAPI + +app = FastAPI() + + +@app.get("/") +def app_root(): + return {"message": "configured app with subcommand config"} diff --git a/tests/assets/projects/configured_app_subtable/pyproject.toml b/tests/assets/projects/configured_app_subtable/pyproject.toml new file mode 100644 index 0000000..ea56896 --- /dev/null +++ b/tests/assets/projects/configured_app_subtable/pyproject.toml @@ -0,0 +1,11 @@ +[tool.fastapi.cli] +# global option +port = 8001 + +[tool.fastapi.cli.run] +reload = true +workers = 4 + +[tool.fastapi.cli.dev] +# overrides global option +port = 8002 diff --git a/tests/conftest.py b/tests/conftest.py index 955fd22..6ca44a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,14 @@ +import inspect import sys +from pathlib import Path from typing import Generator import pytest from fastapi_cli.logging import setup_logging from typer import rich_utils +assets_path = Path(__file__).parent / "assets" + @pytest.fixture(autouse=True) def reset_syspath() -> Generator[None, None, None]: @@ -21,3 +25,23 @@ def setup_terminal() -> None: rich_utils.FORCE_TERMINAL = False setup_logging(terminal_width=3000) return + + +@pytest.fixture(autouse=True) +def asset_import_cleaner() -> Generator[None, None, None]: + existing_imports = set(sys.modules.keys()) + try: + yield + finally: + # clean up imports + new_imports = set(sys.modules.keys()) ^ existing_imports + for name in new_imports: + try: + mod_file = inspect.getfile(sys.modules[name]) + except TypeError: # pragma: no cover + # builtin, ignore + pass + else: + # only clean up imports from the test directory + if mod_file.startswith(str(assets_path)): + del sys.modules[name] diff --git a/tests/test_cli.py b/tests/test_cli.py index 44c14d2..2caac63 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,8 +1,12 @@ +from __future__ import annotations + import subprocess import sys from pathlib import Path +from typing import Any from unittest.mock import patch +import pytest import uvicorn from fastapi_cli.cli import app from typer.testing import CliRunner @@ -82,6 +86,47 @@ def test_dev_args() -> None: assert "│ fastapi run" in result.output +def test_project_run() -> None: + with changing_dir(assets_path / "projects/configured_app"): + with patch.object(uvicorn, "run") as mock_run: + result = runner.invoke(app, ["run"]) + assert result.exit_code == 0, result.output + assert mock_run.called + assert mock_run.call_args + assert mock_run.call_args.kwargs == { + "app": "server:application", + "host": "0.0.0.0", + "port": 8000, + "reload": False, + "workers": None, + "root_path": "", + "proxy_headers": True, + } + + +@pytest.mark.parametrize( + ("command", "kwargs"), + [ + ("run", {"host": "0.0.0.0", "port": 8001, "workers": 4}), + ("dev", {"host": "127.0.0.1", "port": 8002, "workers": None}), + ], +) +def test_project_run_subconfigure(command: str, kwargs: dict[str, Any]) -> None: + with changing_dir(assets_path / "projects/configured_app_subtable"): + with patch.object(uvicorn, "run") as mock_run: + result = runner.invoke(app, [command]) + assert result.exit_code == 0, result.output + assert mock_run.called + assert mock_run.call_args + assert mock_run.call_args.kwargs == { + "app": "app:app", + "reload": True, + "root_path": "", + "proxy_headers": True, + **kwargs, + } + + def test_run() -> None: with changing_dir(assets_path): with patch.object(uvicorn, "run") as mock_run: @@ -159,6 +204,16 @@ def test_run_error() -> None: assert "Path does not exist non_existing_file.py" in result.output +def test_project_config_error() -> None: + with changing_dir(assets_path / "projects/bad_configured_app"): + result = runner.invoke(app, ["run"]) + assert result.exit_code == 2, result.output + assert ( + "Error parsing pyproject.toml: key 'tool.fastapi.cli.run.port'" + in result.output + ) + + def test_dev_help() -> None: result = runner.invoke(app, ["dev", "--help"]) assert result.exit_code == 0, result.output