Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLI bool flags #365

Merged
merged 6 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,75 @@ options:
"""
```

#### CLI Boolean Flags

Change whether boolean fields should be explicit or implicit by default using the `cli_implicit_flags` setting. By
default, boolean fields are "explicit", meaning a boolean value must be explicitly provided on the CLI, e.g.
`--flag=True`. Conversely, boolean fields that are "implicit" derive the value from the flag itself, e.g.
`--flag,--no-flag`, which removes the need for an explicit value to be passed.

Additionally, the provided `CliImplicitFlag` and `CliExplicitFlag` annotations can be used for more granular control
when necessary.

!!! note
For `python < 3.9`:
* The `--no-flag` option is not generated due to an underlying `argparse` limitation.
* The `CliImplicitFlag` and `CliExplicitFlag` annotations can only be applied to optional bool fields.

```py
from pydantic_settings import BaseSettings, CliExplicitFlag, CliImplicitFlag


class ExplicitSettings(BaseSettings, cli_parse_args=True):
"""Boolean fields are explicit by default."""

explicit_req: bool
"""
--explicit_req bool (required)
"""

explicit_opt: bool = False
"""
--explicit_opt bool (default: False)
"""

# Booleans are explicit by default, so must override implicit flags with annotation
implicit_req: CliImplicitFlag[bool]
"""
--implicit_req, --no-implicit_req (required)
"""

implicit_opt: CliImplicitFlag[bool] = False
"""
--implicit_opt, --no-implicit_opt (default: False)
"""


class ImplicitSettings(BaseSettings, cli_parse_args=True, cli_implicit_flags=True):
"""With cli_implicit_flags=True, boolean fields are implicit by default."""

# Booleans are implicit by default, so must override explicit flags with annotation
explicit_req: CliExplicitFlag[bool]
"""
--explicit_req bool (required)
"""

explicit_opt: CliExplicitFlag[bool] = False
"""
--explicit_opt bool (default: False)
"""

implicit_req: bool
"""
--implicit_req, --no-implicit_req (required)
"""

implicit_opt: bool = False
"""
--implicit_opt, --no-implicit_opt (default: False)
"""
```

#### Change Whether CLI Should Exit on Error

Change whether the CLI internal parser will exit on error or raise a `SettingsError` exception by using
Expand Down
4 changes: 4 additions & 0 deletions pydantic_settings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .main import BaseSettings, SettingsConfigDict
from .sources import (
AzureKeyVaultSettingsSource,
CliExplicitFlag,
CliImplicitFlag,
CliPositionalArg,
CliSettingsSource,
CliSubCommand,
Expand All @@ -24,6 +26,8 @@
'CliSettingsSource',
'CliSubCommand',
'CliPositionalArg',
'CliExplicitFlag',
'CliImplicitFlag',
'InitSettingsSource',
'JsonConfigSettingsSource',
'PyprojectTomlConfigSettingsSource',
Expand Down
11 changes: 11 additions & 0 deletions pydantic_settings/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class SettingsConfigDict(ConfigDict, total=False):
cli_use_class_docs_for_groups: bool
cli_exit_on_error: bool
cli_prefix: str
cli_implicit_flags: bool | None
secrets_dir: str | Path | None
json_file: PathType | None
json_file_encoding: str | None
Expand Down Expand Up @@ -114,6 +115,8 @@ class BaseSettings(BaseModel):
_cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
Defaults to `True`.
_cli_prefix: The root parser command line arguments prefix. Defaults to "".
_cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags.
(e.g. --flag, --no-flag). Defaults to `False`.
_secrets_dir: The secret files directory. Defaults to `None`.
"""

Expand All @@ -137,6 +140,7 @@ def __init__(
_cli_use_class_docs_for_groups: bool | None = None,
_cli_exit_on_error: bool | None = None,
_cli_prefix: str | None = None,
_cli_implicit_flags: bool | None = None,
_secrets_dir: str | Path | None = None,
**values: Any,
) -> None:
Expand All @@ -162,6 +166,7 @@ def __init__(
_cli_use_class_docs_for_groups=_cli_use_class_docs_for_groups,
_cli_exit_on_error=_cli_exit_on_error,
_cli_prefix=_cli_prefix,
_cli_implicit_flags=_cli_implicit_flags,
_secrets_dir=_secrets_dir,
)
)
Expand Down Expand Up @@ -211,6 +216,7 @@ def _settings_build_values(
_cli_use_class_docs_for_groups: bool | None = None,
_cli_exit_on_error: bool | None = None,
_cli_prefix: str | None = None,
_cli_implicit_flags: bool | None = None,
_secrets_dir: str | Path | None = None,
) -> dict[str, Any]:
# Determine settings config values
Expand Down Expand Up @@ -260,6 +266,9 @@ def _settings_build_values(
_cli_exit_on_error if _cli_exit_on_error is not None else self.model_config.get('cli_exit_on_error')
)
cli_prefix = _cli_prefix if _cli_prefix is not None else self.model_config.get('cli_prefix')
cli_implicit_flags = (
_cli_implicit_flags if _cli_implicit_flags is not None else self.model_config.get('cli_implicit_flags')
)

secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir')

Expand Down Expand Up @@ -311,6 +320,7 @@ def _settings_build_values(
cli_use_class_docs_for_groups=cli_use_class_docs_for_groups,
cli_exit_on_error=cli_exit_on_error,
cli_prefix=cli_prefix,
cli_implicit_flags=cli_implicit_flags,
case_sensitive=case_sensitive,
)
if cli_settings_source is None
Expand Down Expand Up @@ -358,6 +368,7 @@ def _settings_build_values(
cli_use_class_docs_for_groups=False,
cli_exit_on_error=True,
cli_prefix='',
cli_implicit_flags=False,
json_file=None,
json_file_encoding=None,
yaml_file=None,
Expand Down
58 changes: 58 additions & 0 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import typing
import warnings
from abc import ABC, abstractmethod

if sys.version_info >= (3, 9):
from argparse import BooleanOptionalAction
from argparse import SUPPRESS, ArgumentParser, Namespace, RawDescriptionHelpFormatter, _SubParsersAction
from collections import deque
from dataclasses import is_dataclass
Expand Down Expand Up @@ -124,6 +127,14 @@ class _CliPositionalArg:
pass


class _CliImplicitFlag:
pass


class _CliExplicitFlag:
pass


class _CliInternalArgParser(ArgumentParser):
def __init__(self, cli_exit_on_error: bool = True, **kwargs: Any) -> None:
super().__init__(**kwargs)
Expand All @@ -138,6 +149,9 @@ def error(self, message: str) -> NoReturn:
T = TypeVar('T')
CliSubCommand = Annotated[Union[T, None], _CliSubCommand]
CliPositionalArg = Annotated[T, _CliPositionalArg]
_CliBoolFlag = TypeVar('_CliBoolFlag', bound=bool)
CliImplicitFlag = Annotated[_CliBoolFlag, _CliImplicitFlag]
CliExplicitFlag = Annotated[_CliBoolFlag, _CliExplicitFlag]


class EnvNoneType(str):
Expand Down Expand Up @@ -905,6 +919,8 @@ class CliSettingsSource(EnvSettingsSource, Generic[T]):
cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
Defaults to `True`.
cli_prefix: Prefix for command line arguments added under the root parser. Defaults to "".
cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags.
(e.g. --flag, --no-flag). Defaults to `False`.
case_sensitive: Whether CLI "--arg" names should be read with case-sensitivity. Defaults to `True`.
Note: Case-insensitive matching is only supported on the internal root parser and does not apply to CLI
subcommands.
Expand Down Expand Up @@ -932,6 +948,7 @@ def __init__(
cli_use_class_docs_for_groups: bool | None = None,
cli_exit_on_error: bool | None = None,
cli_prefix: str | None = None,
cli_implicit_flags: bool | None = None,
case_sensitive: bool | None = True,
root_parser: Any = None,
parse_args_method: Callable[..., Any] | None = ArgumentParser.parse_args,
Expand Down Expand Up @@ -975,6 +992,11 @@ def __init__(
if cli_prefix.startswith('.') or cli_prefix.endswith('.') or not cli_prefix.replace('.', '').isidentifier(): # type: ignore
raise SettingsError(f'CLI settings source prefix is invalid: {cli_prefix}')
self.cli_prefix += '.'
self.cli_implicit_flags = (
cli_implicit_flags
if cli_implicit_flags is not None
else settings_cls.model_config.get('cli_implicit_flags', False)
)

case_sensitive = case_sensitive if case_sensitive is not None else True
if not case_sensitive and root_parser is not None:
Expand Down Expand Up @@ -1281,6 +1303,23 @@ def _get_resolved_names(
resolved_names = [resolved_name.lower() for resolved_name in resolved_names]
return tuple(dict.fromkeys(resolved_names)), is_alias_path_only

def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None:
if _CliImplicitFlag in field_info.metadata:
cli_flag_name = 'CliImplicitFlag'
elif _CliExplicitFlag in field_info.metadata:
cli_flag_name = 'CliExplicitFlag'
else:
return

if field_info.annotation is not bool:
raise SettingsError(f'{cli_flag_name} argument {model.__name__}.{field_name} is not of type bool')
elif sys.version_info < (3, 9) and (
field_info.default is PydanticUndefined and field_info.default_factory is None
):
raise SettingsError(
f'{cli_flag_name} argument {model.__name__}.{field_name} must have default for python versions < 3.9'
)

def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]:
positional_args, subcommand_args, optional_args = [], [], []
fields = (
Expand Down Expand Up @@ -1310,6 +1349,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
raise SettingsError(f'positional argument {model.__name__}.{field_name} has an alias')
positional_args.append((field_name, field_info))
else:
self._verify_cli_flag_annotations(model, field_name, field_info)
optional_args.append((field_name, field_info))
return positional_args + subcommand_args + optional_args

Expand Down Expand Up @@ -1457,6 +1497,8 @@ def _add_parser_args(
del kwargs['required']
arg_flag = ''

self._convert_bool_flag(kwargs, field_info, model_default)

if sub_models and kwargs.get('action') != 'append':
self._add_parser_submodels(
parser,
Expand Down Expand Up @@ -1486,6 +1528,22 @@ def _add_parser_args(
self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
return parser

def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, model_default: Any) -> None:
if kwargs['metavar'] == 'bool':
default = None
if field_info.default is not PydanticUndefined:
default = field_info.default
if model_default is not PydanticUndefined:
default = model_default
if sys.version_info >= (3, 9) or isinstance(default, bool):
if (self.cli_implicit_flags or _CliImplicitFlag in field_info.metadata) and (
_CliExplicitFlag not in field_info.metadata
):
del kwargs['metavar']
kwargs['action'] = (
BooleanOptionalAction if sys.version_info >= (3, 9) else f'store_{str(not default).lower()}'
)

def _get_arg_names(
self, arg_prefix: str, subcommand_prefix: str, alias_prefixes: list[str], resolved_names: tuple[str, ...]
) -> list[str]:
Expand Down
74 changes: 73 additions & 1 deletion tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,14 @@
TomlConfigSettingsSource,
YamlConfigSettingsSource,
)
from pydantic_settings.sources import CliPositionalArg, CliSettingsSource, CliSubCommand, SettingsError
from pydantic_settings.sources import (
CliExplicitFlag,
CliImplicitFlag,
CliPositionalArg,
CliSettingsSource,
CliSubCommand,
SettingsError,
)

try:
import dotenv
Expand Down Expand Up @@ -3119,6 +3126,71 @@ class InvalidCliParseArgsType(BaseSettings, cli_parse_args='invalid type'):

InvalidCliParseArgsType()

with pytest.raises(SettingsError, match='CliExplicitFlag argument CliFlagNotBool.flag is not of type bool'):

class CliFlagNotBool(BaseSettings, cli_parse_args=True):
flag: CliExplicitFlag[int] = False

CliFlagNotBool()

if sys.version_info < (3, 9):
with pytest.raises(
SettingsError,
match='CliImplicitFlag argument CliFlag38NotOpt.flag must have default for python versions < 3.9',
):

class CliFlag38NotOpt(BaseSettings, cli_parse_args=True):
flag: CliImplicitFlag[bool]

CliFlag38NotOpt()


@pytest.mark.parametrize('enforce_required', [True, False])
def test_cli_bool_flags(monkeypatch, enforce_required):
if sys.version_info < (3, 9):

class ExplicitSettings(BaseSettings, cli_enforce_required=enforce_required):
explicit_req: bool
explicit_opt: bool = False
implicit_opt: CliImplicitFlag[bool] = False

class ImplicitSettings(BaseSettings, cli_implicit_flags=True, cli_enforce_required=enforce_required):
explicit_req: bool
explicit_opt: CliExplicitFlag[bool] = False
implicit_opt: bool = False

expected = {
'explicit_req': True,
'explicit_opt': False,
'implicit_opt': False,
}

assert ExplicitSettings(_cli_parse_args=['--explicit_req=True']).model_dump() == expected
assert ImplicitSettings(_cli_parse_args=['--explicit_req=True']).model_dump() == expected
else:

class ExplicitSettings(BaseSettings, cli_enforce_required=enforce_required):
explicit_req: bool
explicit_opt: bool = False
implicit_req: CliImplicitFlag[bool]
implicit_opt: CliImplicitFlag[bool] = False

class ImplicitSettings(BaseSettings, cli_implicit_flags=True, cli_enforce_required=enforce_required):
explicit_req: CliExplicitFlag[bool]
explicit_opt: CliExplicitFlag[bool] = False
implicit_req: bool
implicit_opt: bool = False

expected = {
'explicit_req': True,
'explicit_opt': False,
'implicit_req': True,
'implicit_opt': False,
}

assert ExplicitSettings(_cli_parse_args=['--explicit_req=True', '--implicit_req']).model_dump() == expected
assert ImplicitSettings(_cli_parse_args=['--explicit_req=True', '--implicit_req']).model_dump() == expected


def test_cli_avoid_json(capsys, monkeypatch):
class SubModel(BaseModel):
Expand Down
Loading