Skip to content

Commit b55eefd

Browse files
committed
feat: did you mean?
Signed-off-by: Henry Schreiner <[email protected]>
1 parent 362a5a1 commit b55eefd

File tree

4 files changed

+139
-9
lines changed

4 files changed

+139
-9
lines changed

src/scikit_build_core/settings/skbuild_read_settings.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import difflib
34
import sys
45
from collections.abc import Generator, Mapping
56
from pathlib import Path
@@ -50,15 +51,41 @@ def __init__(
5051
def unrecognized_options(self) -> Generator[str, None, None]:
5152
return self.sources.unrecognized_options(ScikitBuildSettings)
5253

54+
def suggestions(self, index: int) -> dict[str, list[str]]:
55+
all_options = list(self.sources[index].all_option_names(ScikitBuildSettings))
56+
result: dict[str, list[str]] = {
57+
k: [] for k in self.sources[index].unrecognized_options(ScikitBuildSettings)
58+
}
59+
for option in result:
60+
prefixed = [x for x in all_options if x.startswith(option)]
61+
if prefixed:
62+
result[option] = prefixed
63+
else:
64+
possibilities = {
65+
".".join(k.split(".")[: option.count(".") + 1]) for k in all_options
66+
}
67+
result[option] = difflib.get_close_matches(option, possibilities, n=3)
68+
69+
return result
70+
71+
def print_suggestions(self) -> None:
72+
for index in (1, 2):
73+
name = {1: "config-settings", 2: "pyproject.toml"}[index]
74+
suggestions_dict = self.suggestions(index)
75+
if suggestions_dict:
76+
rich_print(f"[red][bold]ERROR:[/bold] Unrecognized options in {name}:")
77+
for option, suggestions in suggestions_dict.items():
78+
rich_print(f" [red]{option}", end="")
79+
if suggestions:
80+
sugstr = ", ".join(suggestions)
81+
rich_print(f"[yellow] -> Did you mean: {sugstr}?", end="")
82+
rich_print()
83+
5384
def validate_may_exit(self) -> None:
5485
unrecognized = list(self.unrecognized_options())
5586
if unrecognized:
5687
if self.settings.strict_config:
5788
sys.stdout.flush()
58-
rich_print(
59-
"[red][bold]ERROR:[/bold] Unrecognized options:", file=sys.stderr
60-
)
61-
for option in unrecognized:
62-
rich_print(f" [red]{option}", file=sys.stderr)
89+
self.print_suggestions()
6390
raise SystemExit(7)
6491
logger.warning("Unrecognized options: {}", ", ".join(unrecognized))

src/scikit_build_core/settings/sources.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44
import os
5-
from collections.abc import Generator, Mapping, Sequence
5+
from collections.abc import Generator, Iterator, Mapping, Sequence
66
from typing import Any, TypeVar, Union
77

88
from .._compat.builtins import ExceptionGroup
@@ -102,6 +102,18 @@ def _get_inner_type(target: type[Any]) -> type[Any]:
102102
raise AssertionError("Expected a list or dict")
103103

104104

105+
def _nested_dataclass_to_names(target: type[Any], *inner: str) -> Iterator[list[str]]:
106+
"""
107+
Yields each entry, like ("a", "b", "c") for a.b.c
108+
"""
109+
110+
if dataclasses.is_dataclass(target):
111+
for field in dataclasses.fields(target):
112+
yield from _nested_dataclass_to_names(field.type, *inner, field.name)
113+
else:
114+
yield list(inner)
115+
116+
105117
class Source(Protocol):
106118
def has_item(self, *fields: str, is_dict: bool) -> bool:
107119
"""
@@ -121,6 +133,9 @@ def convert(cls, item: Any, target: type[Any]) -> object:
121133
def unrecognized_options(self, options: object) -> Generator[str, None, None]:
122134
...
123135

136+
def all_option_names(self, target: type[Any]) -> Iterator[str]:
137+
...
138+
124139

125140
class EnvSource:
126141
"""
@@ -170,6 +185,11 @@ def convert(cls, item: str, target: type[Any]) -> object:
170185
def unrecognized_options(self, options: object) -> Generator[str, None, None]:
171186
yield from ()
172187

188+
def all_option_names(self, target: type[Any]) -> Iterator[str]:
189+
prefix = [self.prefix] if self.prefix else []
190+
for names in _nested_dataclass_to_names(target):
191+
yield "_".join(prefix + names).upper()
192+
173193

174194
def _unrecognized_dict(
175195
settings: Mapping[str, Any], options: object, above: Sequence[str]
@@ -282,6 +302,11 @@ def unrecognized_options(self, options: object) -> Generator[str, None, None]:
282302
if _get_target_raw_type(outer_option) == dict:
283303
continue
284304

305+
def all_option_names(self, target: type[Any]) -> Iterator[str]:
306+
for names in _nested_dataclass_to_names(target):
307+
dash_names = [name.replace("_", "-") for name in names]
308+
yield ".".join((*self.prefixes, *dash_names))
309+
285310

286311
class TOMLSource:
287312
def __init__(self, *prefixes: str, settings: Mapping[str, Any]):
@@ -322,11 +347,22 @@ def convert(cls, item: Any, target: type[Any]) -> object:
322347
def unrecognized_options(self, options: object) -> Generator[str, None, None]:
323348
yield from _unrecognized_dict(self.settings, options, self.prefixes)
324349

350+
def all_option_names(self, target: type[Any]) -> Iterator[str]:
351+
for names in _nested_dataclass_to_names(target):
352+
dash_names = [name.replace("_", "-") for name in names]
353+
yield ".".join((*self.prefixes, *dash_names))
354+
325355

326356
class SourceChain:
327-
def __init__(self, *sources: Source):
357+
def __init__(self, *sources: Source) -> None:
328358
self.sources = sources
329359

360+
def __iter__(self) -> Iterator[Source]:
361+
return iter(self.sources)
362+
363+
def __getitem__(self, index: int) -> Source:
364+
return self.sources[index]
365+
330366
def has_item(self, *fields: str, is_dict: bool) -> bool:
331367
for source in self.sources:
332368
if source.has_item(*fields, is_dict=is_dict):
@@ -399,3 +435,7 @@ def convert_target(self, target: type[T], *prefixes: str) -> T:
399435
def unrecognized_options(self, options: object) -> Generator[str, None, None]:
400436
for source in self.sources:
401437
yield from source.unrecognized_options(options)
438+
439+
def all_option_names(self, target: type[Any]) -> Iterator[str]:
440+
for source in self.sources:
441+
yield from source.all_option_names(target)

tests/test_settings.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,27 @@ def test_toml():
149149
assert settings.nine == {"thing": 8}
150150

151151

152+
def test_all_names():
153+
154+
keys = [x.name for x in dataclasses.fields(SettingChecker)]
155+
156+
envame = [f"SKBUILD_{x.upper()}" for x in keys]
157+
assert list(EnvSource("SKBUILD").all_option_names(SettingChecker)) == envame
158+
159+
assert list(ConfSource(settings={}).all_option_names(SettingChecker)) == keys
160+
skkeys = [f"skbuild.{x}" for x in keys]
161+
assert (
162+
list(ConfSource("skbuild", settings={}).all_option_names(SettingChecker))
163+
== skkeys
164+
)
165+
166+
assert list(TOMLSource(settings={}).all_option_names(SettingChecker)) == keys
167+
assert (
168+
list(TOMLSource("skbuild", settings={}).all_option_names(SettingChecker))
169+
== skkeys
170+
)
171+
172+
152173
@dataclasses.dataclass
153174
class NestedSettingChecker:
154175
zero: Path
@@ -247,6 +268,28 @@ def test_toml_nested():
247268
assert settings.three == 3
248269

249270

271+
def test_all_names_nested():
272+
keys_two = [x.name for x in dataclasses.fields(SettingChecker)]
273+
ikeys = [["zero"], ["one"], *[["two", k] for k in keys_two], ["three"]]
274+
275+
envame = [f"SKBUILD_{'_'.join(x).upper()}" for x in ikeys]
276+
assert list(EnvSource("SKBUILD").all_option_names(NestedSettingChecker)) == envame
277+
278+
keys = [".".join(x) for x in ikeys]
279+
assert list(ConfSource(settings={}).all_option_names(NestedSettingChecker)) == keys
280+
skkeys = [f"skbuild.{x}" for x in keys]
281+
assert (
282+
list(ConfSource("skbuild", settings={}).all_option_names(NestedSettingChecker))
283+
== skkeys
284+
)
285+
286+
assert list(TOMLSource(settings={}).all_option_names(NestedSettingChecker)) == keys
287+
assert (
288+
list(TOMLSource("skbuild", settings={}).all_option_names(NestedSettingChecker))
289+
== skkeys
290+
)
291+
292+
250293
@dataclasses.dataclass
251294
class SettingBools:
252295
false: bool = False

tests/test_skbuild_settings.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def test_skbuild_settings_pyproject_toml(tmp_path, monkeypatch):
203203
assert settings.minimum_version == "0.1"
204204

205205

206-
def test_skbuild_settings_pyproject_toml_broken(tmp_path, monkeypatch):
206+
def test_skbuild_settings_pyproject_toml_broken(tmp_path, capsys):
207207
pyproject_toml = tmp_path / "pyproject.toml"
208208
pyproject_toml.write_text(
209209
textwrap.dedent(
@@ -229,8 +229,18 @@ def test_skbuild_settings_pyproject_toml_broken(tmp_path, monkeypatch):
229229
with pytest.raises(SystemExit):
230230
settings_reader.validate_may_exit()
231231

232+
ex = capsys.readouterr().out
233+
assert (
234+
ex.split()
235+
== """\
236+
ERROR: Unrecognized options in pyproject.toml:
237+
tool.scikit-build.cmake.minimum-verison -> Did you mean: tool.scikit-build.cmake.minimum-version, tool.scikit-build.minimum-version, tool.scikit-build.ninja.minimum-version?
238+
tool.scikit-build.logger -> Did you mean: tool.scikit-build.logging, tool.scikit-build.wheel, tool.scikit-build.cmake?
239+
""".split()
240+
)
241+
232242

233-
def test_skbuild_settings_pyproject_conf_broken(tmp_path):
243+
def test_skbuild_settings_pyproject_conf_broken(tmp_path, capsys):
234244
pyproject_toml = tmp_path / "pyproject.toml"
235245
pyproject_toml.write_text("", encoding="utf-8")
236246

@@ -249,3 +259,13 @@ def test_skbuild_settings_pyproject_conf_broken(tmp_path):
249259

250260
with pytest.raises(SystemExit):
251261
settings_reader.validate_may_exit()
262+
263+
ex = capsys.readouterr().out
264+
assert (
265+
ex.split()
266+
== """\
267+
ERROR: Unrecognized options in config-settings:
268+
cmake.minimum-verison -> Did you mean: cmake.minimum-version, minimum-version, ninja.minimum-version?
269+
logger -> Did you mean: logging?
270+
""".split()
271+
)

0 commit comments

Comments
 (0)