Skip to content

CLI support for optional and variadic positional args #519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,10 @@ print(User().model_dump())

### Subcommands and Positional Arguments

Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. These
annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore,
subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`.
Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. The
subcommand annotation can only be applied to required fields (i.e. fields that do not have a default value).
Furthermore, subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses
`dataclass`.

Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is
not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found.
Expand Down
38 changes: 26 additions & 12 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,7 +1333,11 @@ def _load_env_vars(
if subcommand_dest not in selected_subcommands:
parsed_args[subcommand_dest] = self.cli_parse_none_str

parsed_args = {key: val for key, val in parsed_args.items() if not key.endswith(':subcommand')}
parsed_args = {
key: val
for key, val in parsed_args.items()
if not key.endswith(':subcommand') and val is not PydanticUndefined
}
if selected_subcommands:
last_selected_subcommand = max(selected_subcommands, key=len)
if not any(field_name for field_name in parsed_args.keys() if f'{last_selected_subcommand}.' in field_name):
Expand Down Expand Up @@ -1511,12 +1515,9 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
)
subcommand_args.append((field_name, field_info))
elif _CliPositionalArg in field_info.metadata:
if not field_info.is_required():
raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value')
else:
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
positional_args.append((field_name, field_info))
else:
self._verify_cli_flag_annotations(model, field_name, field_info)
Expand Down Expand Up @@ -1727,11 +1728,7 @@ def _add_parser_args(
self._cli_dict_args[kwargs['dest']] = field_info.annotation

if _CliPositionalArg in field_info.metadata:
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())
arg_names = [kwargs['dest']]
del kwargs['dest']
del kwargs['required']
flag_prefix = ''
arg_names, flag_prefix = self._convert_positional_arg(kwargs, field_info, preferred_alias)

self._convert_bool_flag(kwargs, field_info, model_default)

Expand Down Expand Up @@ -1787,6 +1784,23 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode
BooleanOptionalAction if sys.version_info >= (3, 9) else f'store_{str(not default).lower()}'
)

def _convert_positional_arg(
self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str
) -> tuple[list[str], str]:
flag_prefix = ''
arg_names = [kwargs['dest']]
kwargs['default'] = PydanticUndefined
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())

# Note: For positional args, we must strictly look at field_info.is_required instead of our derived
# kwargs['required'].
if not field_info.is_required():
kwargs['nargs'] = '?'

del kwargs['dest']
del kwargs['required']
return arg_names, flag_prefix

def _get_arg_names(
self,
arg_prefix: str,
Expand Down
25 changes: 16 additions & 9 deletions tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,22 @@ class Cfg(BaseSettings):
assert cfg.model_dump() == {'child': {'name': 'new name a', 'diff_a': 'new diff a'}}


def test_cli_optional_positional_arg(env):
class Main(BaseSettings):
model_config = SettingsConfigDict(
cli_parse_args=True,
cli_enforce_required=True,
)

value: CliPositionalArg[int] = 123

assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 123}

env.set('VALUE', '456')
assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 456}

assert CliApp.run(Main, cli_args=['789']).model_dump() == {'value': 789}

def test_cli_enums(capsys, monkeypatch):
class Pet(IntEnum):
dog = 0
Expand Down Expand Up @@ -1415,15 +1431,6 @@ class PositionalArgNotOutermost(BaseSettings, cli_parse_args=True):

PositionalArgNotOutermost()

with pytest.raises(
SettingsError, match='positional argument PositionalArgHasDefault.pos_arg has a default value'
):

class PositionalArgHasDefault(BaseSettings, cli_parse_args=True):
pos_arg: CliPositionalArg[str] = 'bad'

PositionalArgHasDefault()

with pytest.raises(
SettingsError, match=re.escape("cli_parse_args must be List[str] or Tuple[str, ...], recieved <class 'str'>")
):
Expand Down
Loading