Skip to content

feat: did you mean? #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions src/scikit_build_core/settings/skbuild_read_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import difflib
import sys
from collections.abc import Generator, Mapping
from pathlib import Path
Expand Down Expand Up @@ -50,15 +51,37 @@ def __init__(
def unrecognized_options(self) -> Generator[str, None, None]:
return self.sources.unrecognized_options(ScikitBuildSettings)

def suggestions(self, index: int) -> dict[str, list[str]]:
all_options = list(self.sources[index].all_option_names(ScikitBuildSettings))
result: dict[str, list[str]] = {
k: [] for k in self.sources[index].unrecognized_options(ScikitBuildSettings)
}
for option in result:
possibilities = {
".".join(k.split(".")[: option.count(".") + 1]) for k in all_options
}
result[option] = difflib.get_close_matches(option, possibilities, n=3)

return result

def print_suggestions(self) -> None:
for index in (1, 2):
name = {1: "config-settings", 2: "pyproject.toml"}[index]
suggestions_dict = self.suggestions(index)
if suggestions_dict:
rich_print(f"[red][bold]ERROR:[/bold] Unrecognized options in {name}:")
for option, suggestions in suggestions_dict.items():
rich_print(f" [red]{option}", end="")
if suggestions:
sugstr = ", ".join(suggestions)
rich_print(f"[yellow] -> Did you mean: {sugstr}?", end="")
rich_print()

def validate_may_exit(self) -> None:
unrecognized = list(self.unrecognized_options())
if unrecognized:
if self.settings.strict_config:
sys.stdout.flush()
rich_print(
"[red][bold]ERROR:[/bold] Unrecognized options:", file=sys.stderr
)
for option in unrecognized:
rich_print(f" [red]{option}", file=sys.stderr)
self.print_suggestions()
raise SystemExit(7)
logger.warning("Unrecognized options: {}", ", ".join(unrecognized))
41 changes: 39 additions & 2 deletions src/scikit_build_core/settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
import os
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Generator, Iterator, Mapping, Sequence
from typing import Any, TypeVar, Union

from .._compat.builtins import ExceptionGroup
Expand Down Expand Up @@ -102,6 +102,18 @@ def _get_inner_type(target: type[Any]) -> type[Any]:
raise AssertionError("Expected a list or dict")


def _nested_dataclass_to_names(target: type[Any], *inner: str) -> Iterator[list[str]]:
"""
Yields each entry, like ("a", "b", "c") for a.b.c
"""

if dataclasses.is_dataclass(target):
for field in dataclasses.fields(target):
yield from _nested_dataclass_to_names(field.type, *inner, field.name)
else:
yield list(inner)


class Source(Protocol):
def has_item(self, *fields: str, is_dict: bool) -> bool:
"""
Expand All @@ -121,6 +133,9 @@ def convert(cls, item: Any, target: type[Any]) -> object:
def unrecognized_options(self, options: object) -> Generator[str, None, None]:
...

def all_option_names(self, target: type[Any]) -> Iterator[str]:
...


class EnvSource:
"""
Expand Down Expand Up @@ -170,6 +185,11 @@ def convert(cls, item: str, target: type[Any]) -> object:
def unrecognized_options(self, options: object) -> Generator[str, None, None]:
yield from ()

def all_option_names(self, target: type[Any]) -> Iterator[str]:
prefix = [self.prefix] if self.prefix else []
for names in _nested_dataclass_to_names(target):
yield "_".join(prefix + names).upper()


def _unrecognized_dict(
settings: Mapping[str, Any], options: object, above: Sequence[str]
Expand Down Expand Up @@ -282,6 +302,11 @@ def unrecognized_options(self, options: object) -> Generator[str, None, None]:
if _get_target_raw_type(outer_option) == dict:
continue

def all_option_names(self, target: type[Any]) -> Iterator[str]:
for names in _nested_dataclass_to_names(target):
dash_names = [name.replace("_", "-") for name in names]
yield ".".join((*self.prefixes, *dash_names))


class TOMLSource:
def __init__(self, *prefixes: str, settings: Mapping[str, Any]):
Expand Down Expand Up @@ -322,11 +347,19 @@ def convert(cls, item: Any, target: type[Any]) -> object:
def unrecognized_options(self, options: object) -> Generator[str, None, None]:
yield from _unrecognized_dict(self.settings, options, self.prefixes)

def all_option_names(self, target: type[Any]) -> Iterator[str]:
for names in _nested_dataclass_to_names(target):
dash_names = [name.replace("_", "-") for name in names]
yield ".".join((*self.prefixes, *dash_names))


class SourceChain:
def __init__(self, *sources: Source):
def __init__(self, *sources: Source) -> None:
self.sources = sources

def __getitem__(self, index: int) -> Source:
return self.sources[index]

def has_item(self, *fields: str, is_dict: bool) -> bool:
for source in self.sources:
if source.has_item(*fields, is_dict=is_dict):
Expand Down Expand Up @@ -399,3 +432,7 @@ def convert_target(self, target: type[T], *prefixes: str) -> T:
def unrecognized_options(self, options: object) -> Generator[str, None, None]:
for source in self.sources:
yield from source.unrecognized_options(options)

def all_option_names(self, target: type[Any]) -> Iterator[str]:
for source in self.sources:
yield from source.all_option_names(target)
43 changes: 43 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,27 @@ def test_toml():
assert settings.nine == {"thing": 8}


def test_all_names():

keys = [x.name for x in dataclasses.fields(SettingChecker)]

envame = [f"SKBUILD_{x.upper()}" for x in keys]
assert list(EnvSource("SKBUILD").all_option_names(SettingChecker)) == envame

assert list(ConfSource(settings={}).all_option_names(SettingChecker)) == keys
skkeys = [f"skbuild.{x}" for x in keys]
assert (
list(ConfSource("skbuild", settings={}).all_option_names(SettingChecker))
== skkeys
)

assert list(TOMLSource(settings={}).all_option_names(SettingChecker)) == keys
assert (
list(TOMLSource("skbuild", settings={}).all_option_names(SettingChecker))
== skkeys
)


@dataclasses.dataclass
class NestedSettingChecker:
zero: Path
Expand Down Expand Up @@ -247,6 +268,28 @@ def test_toml_nested():
assert settings.three == 3


def test_all_names_nested():
keys_two = [x.name for x in dataclasses.fields(SettingChecker)]
ikeys = [["zero"], ["one"], *[["two", k] for k in keys_two], ["three"]]

envame = [f"SKBUILD_{'_'.join(x).upper()}" for x in ikeys]
assert list(EnvSource("SKBUILD").all_option_names(NestedSettingChecker)) == envame

keys = [".".join(x) for x in ikeys]
assert list(ConfSource(settings={}).all_option_names(NestedSettingChecker)) == keys
skkeys = [f"skbuild.{x}" for x in keys]
assert (
list(ConfSource("skbuild", settings={}).all_option_names(NestedSettingChecker))
== skkeys
)

assert list(TOMLSource(settings={}).all_option_names(NestedSettingChecker)) == keys
assert (
list(TOMLSource("skbuild", settings={}).all_option_names(NestedSettingChecker))
== skkeys
)


@dataclasses.dataclass
class SettingBools:
false: bool = False
Expand Down
28 changes: 26 additions & 2 deletions tests/test_skbuild_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
import textwrap

import pytest
Expand Down Expand Up @@ -203,7 +204,7 @@ def test_skbuild_settings_pyproject_toml(tmp_path, monkeypatch):
assert settings.minimum_version == "0.1"


def test_skbuild_settings_pyproject_toml_broken(tmp_path, monkeypatch):
def test_skbuild_settings_pyproject_toml_broken(tmp_path, capsys):
pyproject_toml = tmp_path / "pyproject.toml"
pyproject_toml.write_text(
textwrap.dedent(
Expand All @@ -229,8 +230,19 @@ def test_skbuild_settings_pyproject_toml_broken(tmp_path, monkeypatch):
with pytest.raises(SystemExit):
settings_reader.validate_may_exit()

ex = capsys.readouterr().out
ex = re.sub(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))", "", ex)
assert (
ex.split()
== """\
ERROR: Unrecognized options in pyproject.toml:
tool.scikit-build.cmake.minimum-verison -> Did you mean: tool.scikit-build.cmake.minimum-version, tool.scikit-build.minimum-version, tool.scikit-build.ninja.minimum-version?
tool.scikit-build.logger -> Did you mean: tool.scikit-build.logging, tool.scikit-build.wheel, tool.scikit-build.cmake?
""".split()
)


def test_skbuild_settings_pyproject_conf_broken(tmp_path):
def test_skbuild_settings_pyproject_conf_broken(tmp_path, capsys):
pyproject_toml = tmp_path / "pyproject.toml"
pyproject_toml.write_text("", encoding="utf-8")

Expand All @@ -249,3 +261,15 @@ def test_skbuild_settings_pyproject_conf_broken(tmp_path):

with pytest.raises(SystemExit):
settings_reader.validate_may_exit()

ex = capsys.readouterr().out
# Filter terminal color codes
ex = re.sub(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))", "", ex)
assert (
ex.split()
== """\
ERROR: Unrecognized options in config-settings:
cmake.minimum-verison -> Did you mean: cmake.minimum-version, minimum-version, ninja.minimum-version?
logger -> Did you mean: logging?
""".split()
)