diff --git a/docs/index.md b/docs/index.md index 85d059d..65bfb58 100644 --- a/docs/index.md +++ b/docs/index.md @@ -868,6 +868,75 @@ options: """ ``` +#### CLI Boolean Flags + +Change whether boolean fields should be explicit or implicit by default using the `cli_implicit_flags` setting. By +default, boolean fields are "explicit", meaning a boolean value must be explicitly provided on the CLI, e.g. +`--flag=True`. Conversely, boolean fields that are "implicit" derive the value from the flag itself, e.g. +`--flag,--no-flag`, which removes the need for an explicit value to be passed. + +Additionally, the provided `CliImplicitFlag` and `CliExplicitFlag` annotations can be used for more granular control +when necessary. + +!!! note +For `python < 3.9`: + * The `--no-flag` option is not generated due to an underlying `argparse` limitation. + * The `CliImplicitFlag` and `CliExplicitFlag` annotations can only be applied to optional bool fields. + +```py +from pydantic_settings import BaseSettings, CliExplicitFlag, CliImplicitFlag + + +class ExplicitSettings(BaseSettings, cli_parse_args=True): + """Boolean fields are explicit by default.""" + + explicit_req: bool + """ + --explicit_req bool (required) + """ + + explicit_opt: bool = False + """ + --explicit_opt bool (default: False) + """ + + # Booleans are explicit by default, so must override implicit flags with annotation + implicit_req: CliImplicitFlag[bool] + """ + --implicit_req, --no-implicit_req (required) + """ + + implicit_opt: CliImplicitFlag[bool] = False + """ + --implicit_opt, --no-implicit_opt (default: False) + """ + + +class ImplicitSettings(BaseSettings, cli_parse_args=True, cli_implicit_flags=True): + """With cli_implicit_flags=True, boolean fields are implicit by default.""" + + # Booleans are implicit by default, so must override explicit flags with annotation + explicit_req: CliExplicitFlag[bool] + """ + --explicit_req bool (required) + """ + + explicit_opt: CliExplicitFlag[bool] = False + """ + --explicit_opt bool (default: False) + """ + + implicit_req: bool + """ + --implicit_req, --no-implicit_req (required) + """ + + implicit_opt: bool = False + """ + --implicit_opt, --no-implicit_opt (default: False) + """ +``` + #### Change Whether CLI Should Exit on Error Change whether the CLI internal parser will exit on error or raise a `SettingsError` exception by using diff --git a/pydantic_settings/__init__.py b/pydantic_settings/__init__.py index c0d5f35..5f979ea 100644 --- a/pydantic_settings/__init__.py +++ b/pydantic_settings/__init__.py @@ -1,6 +1,8 @@ from .main import BaseSettings, SettingsConfigDict from .sources import ( AzureKeyVaultSettingsSource, + CliExplicitFlag, + CliImplicitFlag, CliPositionalArg, CliSettingsSource, CliSubCommand, @@ -24,6 +26,8 @@ 'CliSettingsSource', 'CliSubCommand', 'CliPositionalArg', + 'CliExplicitFlag', + 'CliImplicitFlag', 'InitSettingsSource', 'JsonConfigSettingsSource', 'PyprojectTomlConfigSettingsSource', diff --git a/pydantic_settings/main.py b/pydantic_settings/main.py index 1dd4ac7..5433442 100644 --- a/pydantic_settings/main.py +++ b/pydantic_settings/main.py @@ -40,6 +40,7 @@ class SettingsConfigDict(ConfigDict, total=False): cli_use_class_docs_for_groups: bool cli_exit_on_error: bool cli_prefix: str + cli_implicit_flags: bool | None secrets_dir: str | Path | None json_file: PathType | None json_file_encoding: str | None @@ -114,6 +115,8 @@ class BaseSettings(BaseModel): _cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs. Defaults to `True`. _cli_prefix: The root parser command line arguments prefix. Defaults to "". + _cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags. + (e.g. --flag, --no-flag). Defaults to `False`. _secrets_dir: The secret files directory. Defaults to `None`. """ @@ -137,6 +140,7 @@ def __init__( _cli_use_class_docs_for_groups: bool | None = None, _cli_exit_on_error: bool | None = None, _cli_prefix: str | None = None, + _cli_implicit_flags: bool | None = None, _secrets_dir: str | Path | None = None, **values: Any, ) -> None: @@ -162,6 +166,7 @@ def __init__( _cli_use_class_docs_for_groups=_cli_use_class_docs_for_groups, _cli_exit_on_error=_cli_exit_on_error, _cli_prefix=_cli_prefix, + _cli_implicit_flags=_cli_implicit_flags, _secrets_dir=_secrets_dir, ) ) @@ -211,6 +216,7 @@ def _settings_build_values( _cli_use_class_docs_for_groups: bool | None = None, _cli_exit_on_error: bool | None = None, _cli_prefix: str | None = None, + _cli_implicit_flags: bool | None = None, _secrets_dir: str | Path | None = None, ) -> dict[str, Any]: # Determine settings config values @@ -260,6 +266,9 @@ def _settings_build_values( _cli_exit_on_error if _cli_exit_on_error is not None else self.model_config.get('cli_exit_on_error') ) cli_prefix = _cli_prefix if _cli_prefix is not None else self.model_config.get('cli_prefix') + cli_implicit_flags = ( + _cli_implicit_flags if _cli_implicit_flags is not None else self.model_config.get('cli_implicit_flags') + ) secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir') @@ -311,6 +320,7 @@ def _settings_build_values( cli_use_class_docs_for_groups=cli_use_class_docs_for_groups, cli_exit_on_error=cli_exit_on_error, cli_prefix=cli_prefix, + cli_implicit_flags=cli_implicit_flags, case_sensitive=case_sensitive, ) if cli_settings_source is None @@ -358,6 +368,7 @@ def _settings_build_values( cli_use_class_docs_for_groups=False, cli_exit_on_error=True, cli_prefix='', + cli_implicit_flags=False, json_file=None, json_file_encoding=None, yaml_file=None, diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 156a109..f423b0b 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -8,6 +8,9 @@ import typing import warnings from abc import ABC, abstractmethod + +if sys.version_info >= (3, 9): + from argparse import BooleanOptionalAction from argparse import SUPPRESS, ArgumentParser, Namespace, RawDescriptionHelpFormatter, _SubParsersAction from collections import deque from dataclasses import is_dataclass @@ -124,6 +127,14 @@ class _CliPositionalArg: pass +class _CliImplicitFlag: + pass + + +class _CliExplicitFlag: + pass + + class _CliInternalArgParser(ArgumentParser): def __init__(self, cli_exit_on_error: bool = True, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -138,6 +149,9 @@ def error(self, message: str) -> NoReturn: T = TypeVar('T') CliSubCommand = Annotated[Union[T, None], _CliSubCommand] CliPositionalArg = Annotated[T, _CliPositionalArg] +_CliBoolFlag = TypeVar('_CliBoolFlag', bound=bool) +CliImplicitFlag = Annotated[_CliBoolFlag, _CliImplicitFlag] +CliExplicitFlag = Annotated[_CliBoolFlag, _CliExplicitFlag] class EnvNoneType(str): @@ -905,6 +919,8 @@ class CliSettingsSource(EnvSettingsSource, Generic[T]): cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs. Defaults to `True`. cli_prefix: Prefix for command line arguments added under the root parser. Defaults to "". + cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags. + (e.g. --flag, --no-flag). Defaults to `False`. case_sensitive: Whether CLI "--arg" names should be read with case-sensitivity. Defaults to `True`. Note: Case-insensitive matching is only supported on the internal root parser and does not apply to CLI subcommands. @@ -932,6 +948,7 @@ def __init__( cli_use_class_docs_for_groups: bool | None = None, cli_exit_on_error: bool | None = None, cli_prefix: str | None = None, + cli_implicit_flags: bool | None = None, case_sensitive: bool | None = True, root_parser: Any = None, parse_args_method: Callable[..., Any] | None = ArgumentParser.parse_args, @@ -975,6 +992,11 @@ def __init__( if cli_prefix.startswith('.') or cli_prefix.endswith('.') or not cli_prefix.replace('.', '').isidentifier(): # type: ignore raise SettingsError(f'CLI settings source prefix is invalid: {cli_prefix}') self.cli_prefix += '.' + self.cli_implicit_flags = ( + cli_implicit_flags + if cli_implicit_flags is not None + else settings_cls.model_config.get('cli_implicit_flags', False) + ) case_sensitive = case_sensitive if case_sensitive is not None else True if not case_sensitive and root_parser is not None: @@ -1281,6 +1303,23 @@ def _get_resolved_names( resolved_names = [resolved_name.lower() for resolved_name in resolved_names] return tuple(dict.fromkeys(resolved_names)), is_alias_path_only + def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None: + if _CliImplicitFlag in field_info.metadata: + cli_flag_name = 'CliImplicitFlag' + elif _CliExplicitFlag in field_info.metadata: + cli_flag_name = 'CliExplicitFlag' + else: + return + + if field_info.annotation is not bool: + raise SettingsError(f'{cli_flag_name} argument {model.__name__}.{field_name} is not of type bool') + elif sys.version_info < (3, 9) and ( + field_info.default is PydanticUndefined and field_info.default_factory is None + ): + raise SettingsError( + f'{cli_flag_name} argument {model.__name__}.{field_name} must have default for python versions < 3.9' + ) + def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]: positional_args, subcommand_args, optional_args = [], [], [] fields = ( @@ -1310,6 +1349,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo] raise SettingsError(f'positional argument {model.__name__}.{field_name} has an alias') positional_args.append((field_name, field_info)) else: + self._verify_cli_flag_annotations(model, field_name, field_info) optional_args.append((field_name, field_info)) return positional_args + subcommand_args + optional_args @@ -1457,6 +1497,8 @@ def _add_parser_args( del kwargs['required'] arg_flag = '' + self._convert_bool_flag(kwargs, field_info, model_default) + if sub_models and kwargs.get('action') != 'append': self._add_parser_submodels( parser, @@ -1486,6 +1528,22 @@ def _add_parser_args( self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group) return parser + def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, model_default: Any) -> None: + if kwargs['metavar'] == 'bool': + default = None + if field_info.default is not PydanticUndefined: + default = field_info.default + if model_default is not PydanticUndefined: + default = model_default + if sys.version_info >= (3, 9) or isinstance(default, bool): + if (self.cli_implicit_flags or _CliImplicitFlag in field_info.metadata) and ( + _CliExplicitFlag not in field_info.metadata + ): + del kwargs['metavar'] + kwargs['action'] = ( + BooleanOptionalAction if sys.version_info >= (3, 9) else f'store_{str(not default).lower()}' + ) + def _get_arg_names( self, arg_prefix: str, subcommand_prefix: str, alias_prefixes: list[str], resolved_names: tuple[str, ...] ) -> list[str]: diff --git a/tests/test_settings.py b/tests/test_settings.py index e83be97..7bf0ba7 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -50,7 +50,14 @@ TomlConfigSettingsSource, YamlConfigSettingsSource, ) -from pydantic_settings.sources import CliPositionalArg, CliSettingsSource, CliSubCommand, SettingsError +from pydantic_settings.sources import ( + CliExplicitFlag, + CliImplicitFlag, + CliPositionalArg, + CliSettingsSource, + CliSubCommand, + SettingsError, +) try: import dotenv @@ -3119,6 +3126,71 @@ class InvalidCliParseArgsType(BaseSettings, cli_parse_args='invalid type'): InvalidCliParseArgsType() + with pytest.raises(SettingsError, match='CliExplicitFlag argument CliFlagNotBool.flag is not of type bool'): + + class CliFlagNotBool(BaseSettings, cli_parse_args=True): + flag: CliExplicitFlag[int] = False + + CliFlagNotBool() + + if sys.version_info < (3, 9): + with pytest.raises( + SettingsError, + match='CliImplicitFlag argument CliFlag38NotOpt.flag must have default for python versions < 3.9', + ): + + class CliFlag38NotOpt(BaseSettings, cli_parse_args=True): + flag: CliImplicitFlag[bool] + + CliFlag38NotOpt() + + +@pytest.mark.parametrize('enforce_required', [True, False]) +def test_cli_bool_flags(monkeypatch, enforce_required): + if sys.version_info < (3, 9): + + class ExplicitSettings(BaseSettings, cli_enforce_required=enforce_required): + explicit_req: bool + explicit_opt: bool = False + implicit_opt: CliImplicitFlag[bool] = False + + class ImplicitSettings(BaseSettings, cli_implicit_flags=True, cli_enforce_required=enforce_required): + explicit_req: bool + explicit_opt: CliExplicitFlag[bool] = False + implicit_opt: bool = False + + expected = { + 'explicit_req': True, + 'explicit_opt': False, + 'implicit_opt': False, + } + + assert ExplicitSettings(_cli_parse_args=['--explicit_req=True']).model_dump() == expected + assert ImplicitSettings(_cli_parse_args=['--explicit_req=True']).model_dump() == expected + else: + + class ExplicitSettings(BaseSettings, cli_enforce_required=enforce_required): + explicit_req: bool + explicit_opt: bool = False + implicit_req: CliImplicitFlag[bool] + implicit_opt: CliImplicitFlag[bool] = False + + class ImplicitSettings(BaseSettings, cli_implicit_flags=True, cli_enforce_required=enforce_required): + explicit_req: CliExplicitFlag[bool] + explicit_opt: CliExplicitFlag[bool] = False + implicit_req: bool + implicit_opt: bool = False + + expected = { + 'explicit_req': True, + 'explicit_opt': False, + 'implicit_req': True, + 'implicit_opt': False, + } + + assert ExplicitSettings(_cli_parse_args=['--explicit_req=True', '--implicit_req']).model_dump() == expected + assert ImplicitSettings(_cli_parse_args=['--explicit_req=True', '--implicit_req']).model_dump() == expected + def test_cli_avoid_json(capsys, monkeypatch): class SubModel(BaseModel):