diff --git a/.gitignore b/.gitignore index b2f8a61..d9d484e 100644 --- a/.gitignore +++ b/.gitignore @@ -24,5 +24,8 @@ venv/ .coverage coverage.lcov +# Mypy +.mypy_cache/ + # hatchling autogenerates version /startle/_version.py \ No newline at end of file diff --git a/examples/calc.py b/examples/calc.py new file mode 100644 index 0000000..3f5ec6f --- /dev/null +++ b/examples/calc.py @@ -0,0 +1,49 @@ +from startle import start + + +def add(a: int, b: int) -> None: + """ + Add two numbers. + + Args: + a: The first number. + b: The second number. + """ + print(f"{a} + {b} = {a + b}") + + +def sub(a: int, b: int) -> None: + """ + Subtract two numbers. + + Args: + a: The first number. + b: The second number + """ + print(f"{a} - {b} = {a - b}") + + +def mul(a: int, b: int) -> None: + """ + Multiply two numbers. + + Args: + a: The first number. + b: The second number. + """ + print(f"{a} * {b} = {a * b}") + + +def div(a: int, b: int) -> None: + """ + Divide two numbers. + + Args: + a: The dividend. + b: The divisor. + """ + print(f"{a} / {b} = {a / b}") + + +if __name__ == "__main__": + start([add, sub, mul, div]) diff --git a/examples/cat.py b/examples/cat.py new file mode 100644 index 0000000..0d81cd6 --- /dev/null +++ b/examples/cat.py @@ -0,0 +1,27 @@ +""" +Example invocations (for fish shell): + ❯ python examples/cat.py examples/wc.py examples/cat.py --delim=\n===\n\n + ❯ python examples/cat.py --delim=\n===\n\n examples/cat.py examples/wc.py +""" + +from pathlib import Path + +from startle import start + + +def cat(files: list[Path], /, *, delim: str = "") -> None: + """ + Concatenate files with an optional delimiter. + + Args: + files: The files to concatenate. + delim: The delimiter to use. + """ + + for i, file in enumerate(files): + if i: + print(delim, end="") + print(file.read_text(), end="") + + +start(cat) diff --git a/pyproject.toml b/pyproject.toml index 07b840d..eb91ec9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ select = ["E4", "E7", "E9", "F", "I"] [tool.pytest.ini_options] addopts = "--cov=startle --cov-report=term-missing --cov-report=lcov" +python_files = ["tests/*.py", "tests/**/*.py"] [tool.coverage.run] -omit = ["startle/_version.py"] \ No newline at end of file +omit = ["startle/_version.py"] diff --git a/startle/_start.py b/startle/_start.py index 1393c9e..b9083c3 100644 --- a/startle/_start.py +++ b/startle/_start.py @@ -1,7 +1,10 @@ +import sys from typing import Callable, TypeVar from rich.console import Console +from .args import Args +from .cmds import Cmds from .error import ParserConfigError, ParserOptionError, ParserValueError from .inspect import make_args @@ -9,6 +12,15 @@ def start( + obj: Callable | list[Callable], args: list[str] | None = None, caught: bool = True +): + if isinstance(obj, list): + return _start_cmds(obj, args, caught) + else: + return _start_func(obj, args, caught) + + +def _start_func( func: Callable[..., T], args: list[str] | None = None, caught: bool = True ) -> T: """ @@ -52,3 +64,59 @@ def start( raise SystemExit(1) else: raise e + + +def _start_cmds( + funcs: list[Callable], cli_args: list[str] | None = None, caught: bool = True +): + """ """ + + def cmd_name(func: Callable) -> str: + return func.__name__.replace("_", "-") + + def prog_name(func: Callable) -> str: + return f"{sys.argv[0]} {cmd_name(func)}" + + cmd2func: dict[str, Callable] = {cmd_name(func): func for func in funcs} + + try: + # first, make Cmds object from the functions + cmds = Cmds( + { + cmd_name(func): make_args(func, program_name=prog_name(func)) + for func in funcs + } + ) + except ParserConfigError as e: + if caught: + console = Console() + console.print(f"[bold red]Error:[/bold red] [red]{e}[/red]\n") + raise SystemExit(1) + else: + raise e + + try: + # then, parse the arguments from the CLI + args: Args | None = None + cmd, args = cmds.parse(cli_args) + + # then turn the parsed arguments into function arguments + f_args, f_kwargs = args.make_func_args() + + # finally, call the function with the arguments + func = cmd2func[cmd] + return func(*f_args, **f_kwargs) + except (ParserOptionError, ParserValueError) as e: + if caught: + console = Console() + console.print(f"[bold red]Error:[/bold red] [red]{e}[/red]\n") + if args: # error happened after parsing the command + args.print_help(console, usage_only=True) + else: # error happened before parsing the command + cmds.print_help(console, usage_only=True) + console.print( + "\n[dim]For more information, run with [green][b]-?[/b]|[b]--help[/b][/green].[/dim]" + ) + raise SystemExit(1) + else: + raise e diff --git a/startle/args.py b/startle/args.py index 881327b..2ecef6a 100644 --- a/startle/args.py +++ b/startle/args.py @@ -14,10 +14,12 @@ class Args: """ brief: str = "" + program_name: str = "" _positional_args: list[Arg] = field(default_factory=list) _named_args: list[Arg] = field(default_factory=list) _name2idx: dict[str, int] = field(default_factory=dict) + # note that _name2idx is many to one, because a name can be both short and long _var_args: Arg | None = None # remaining unk args for functions with *args _var_kwargs: Arg | None = None # remaining unk kwargs for functions with **kwargs @@ -274,30 +276,27 @@ def var(opt: Arg | str) -> str: return positional_args, named_args - def parse(self, args: list[str] | None = None) -> "Args": + def parse(self, cli_args: list[str] | None = None) -> "Args": """ Parse the command-line arguments. Args: - args: The arguments to parse. If None, uses the arguments from the CLI. + cli_args: The arguments to parse. If None, uses the arguments from the CLI. Returns: Self, for chaining. """ - if args is not None: - self._parse(args) + if cli_args is not None: + self._parse(cli_args) else: self._parse(sys.argv[1:]) return self - def print_help( - self, console=None, program_name: str | None = None, usage_only: bool = False - ) -> None: + def print_help(self, console=None, usage_only: bool = False) -> None: """ Print the help message to the console. Args: console: A rich console to print to. If None, uses the default console. - program_name: The name of the program to use in the help message. usage_only: Whether to print only the usage line. """ import sys @@ -306,7 +305,7 @@ def print_help( from rich.table import Table from rich.text import Text - name = program_name or sys.argv[0] + name = self.program_name or sys.argv[0] positional_only = [ arg @@ -388,22 +387,30 @@ def help(arg: Arg) -> Text: ) return helptext + # print brief if it exists console = console or Console() if self.brief and not usage_only: console.print(self.brief + "\n") + + # then print usage line console.print(Text("Usage:", style=sty_title)) - console.print( - Text(f" {name} ") - + Text(" ").join([usage(arg, "usage line") for arg in positional_only]) - + Text(" ") - + Text(" ").join( - [usage(opt, "usage line") for opt in positional_and_named + named_only] - ) + usage_line = Text(f" {name}") + pos_only_str = Text(" ").join( + [usage(arg, "usage line") for arg in positional_only] ) + if pos_only_str: + usage_line += Text(" ") + pos_only_str + named_str = Text(" ").join( + [usage(opt, "usage line") for opt in positional_and_named + named_only] + ) + if named_str: + usage_line += Text(" ") + named_str + console.print(usage_line) if usage_only: return + # then print help message for each argument if positional_only + positional_and_named + named_only: console.print(Text("\nwhere", style=sty_title)) @@ -427,11 +434,3 @@ def help(arg: Arg) -> Text: ) console.print(table) - - def __repr__(self) -> str: - rval = "\n" - for arg in self._positional_args: - rval += f" {arg.metavar}: {arg._value}\n" - for arg in self._named_args: - rval += f" {arg.name.long}: {arg._value}\n" - return rval diff --git a/startle/cmds.py b/startle/cmds.py new file mode 100644 index 0000000..4f5037c --- /dev/null +++ b/startle/cmds.py @@ -0,0 +1,113 @@ +import sys +from dataclasses import dataclass, field + +from .args import Args + + +@dataclass +class Cmds: + """ + A parser class which is a collection of Args objects paired with a command. + + Parsing is done by treating the first argument as a command and then + passing the remaining arguments to the Args object associated with that + command. + """ + + cmd_parsers: dict[str, Args] = field(default_factory=dict) + brief: str = "" + + def parse(self, cli_args: list[str] | None = None) -> tuple[str, Args]: + cli_args = cli_args or sys.argv[1:] + + if not cli_args: + print("Error: No command given!") + self.print_help() + raise SystemExit(1) + + cmd = cli_args[0] + if cmd in ["-?", "--help"]: + self.print_help() + raise SystemExit(0) + + if cmd not in self.cmd_parsers: + print(f"Error: Unknown command {cmd}!") + self.print_help() + raise SystemExit(1) + + args = self.cmd_parsers[cmd] + args.parse(cli_args[1:]) + return cmd, args + + def print_help( + self, console=None, program_name: str | None = None, usage_only: bool = False + ) -> None: + """ + Print the help message to the console. + + Args: + console: A rich console to print to. If None, uses the default console. + program_name: The name of the program to use in the help message. + usage_only: Whether to print only the usage line. + """ + import sys + + from rich.console import Console + from rich.table import Table + from rich.text import Text + + name = program_name or sys.argv[0] + + sty_pos_name = "bold" + sty_opt = "green" + sty_var = "blue" + sty_title = "bold underline dim" + sty_help = "italic" + + console = console or Console() + if self.brief and not usage_only: + console.print(self.brief + "\n") + + console.print( + Text.assemble( + "\n", + ("Usage:", sty_title), + "\n", + f" {name} ", + ("<", sty_var), + ("command", f"{sty_var} {sty_pos_name}"), + (">", sty_var), + " ", + ("", sty_var), + "\n", + ) + ) + + console.print(Text("Commands:", style=sty_title)) + + table = Table(show_header=False, box=None, padding=(0, 0, 0, 2)) + for cmd, args in self.cmd_parsers.items(): + brief = args.brief.split("\n\n")[0] + table.add_row( + Text(cmd, style=f"{sty_pos_name} {sty_var}"), + Text(brief, style=sty_help), + ) + console.print(table) + + console.print( + Text.assemble( + "\n", + ("Run ", "dim"), + "`", + name, + " ", + ("<", sty_var), + ("command", f"{sty_var} {sty_pos_name}"), + (">", sty_var), + " ", + ("--help", f"{sty_opt}"), + "`", + (" to see all command-specific options.", "dim"), + "\n", + ) + ) diff --git a/startle/inspect.py b/startle/inspect.py index 3c9e6a2..a0968d7 100644 --- a/startle/inspect.py +++ b/startle/inspect.py @@ -68,14 +68,14 @@ def _parse_docstring(func: Callable) -> tuple[str, dict[str, str]]: return brief, arg_helps -def make_args(func: Callable) -> Args: +def make_args(func: Callable, program_name: str = "") -> Args: # Get the signature of the function sig = inspect.signature(func) # Attempt to parse brief and arg descriptions from docstring brief, arg_helps = _parse_docstring(func) - args = Args(brief=brief) + args = Args(brief=brief, program_name=program_name) used_short_names = set() diff --git a/tests/test_help.py b/tests/test_help.py index 64d0063..e045fae 100644 --- a/tests/test_help.py +++ b/tests/test_help.py @@ -13,7 +13,7 @@ def check_help(f: Callable, program_name: str, expected: str): console = Console(width=120, highlight=False) with console.capture() as capture: - make_args(f).print_help(console, program_name) + make_args(f, program_name).print_help(console) result = capture.get() console = Console(width=120, highlight=False) diff --git a/tests/test_parser/__init__.py b/tests/test_parser/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_parser/test_parse_common.py b/tests/test_parser/test_parse_common.py index f295e0a..1ca8296 100644 --- a/tests/test_parser/test_parse_common.py +++ b/tests/test_parser/test_parse_common.py @@ -1,11 +1,12 @@ import re -from _utils import check_args from pytest import mark, raises from startle.error import ParserConfigError, ParserOptionError, ParserValueError from startle.inspect import make_args +from ._utils import check_args + def hi_int(name: str = "john", /, *, count: int = 1) -> None: for _ in range(count): diff --git a/tests/test_parser/test_parse_enum.py b/tests/test_parser/test_parse_enum.py index bfae710..6576f23 100644 --- a/tests/test_parser/test_parse_enum.py +++ b/tests/test_parser/test_parse_enum.py @@ -2,11 +2,12 @@ from enum import Enum, IntEnum from typing import Callable -from _utils import Opt, Opts, check_args from pytest import mark, raises from startle.error import ParserOptionError, ParserValueError +from ._utils import Opt, Opts, check_args + def check(draw: Callable, shape: type[Enum], opt: Opt): check_args(draw, opt("shape", ["square"]), [shape.SQUARE], {}) diff --git a/tests/test_parser/test_parse_list.py b/tests/test_parser/test_parse_list.py index b195ccb..3237d80 100644 --- a/tests/test_parser/test_parse_list.py +++ b/tests/test_parser/test_parse_list.py @@ -1,10 +1,11 @@ from typing import Callable, List -from _utils import check_args from pytest import mark, raises from startle.error import ParserOptionError, ParserValueError +from ._utils import check_args + def add_int(*, numbers: list[int]) -> None: print(sum(numbers)) diff --git a/tests/test_parser/test_parse_literal.py b/tests/test_parser/test_parse_literal.py index a36211b..9f6f996 100644 --- a/tests/test_parser/test_parse_literal.py +++ b/tests/test_parser/test_parse_literal.py @@ -1,11 +1,12 @@ import re from typing import Callable, Literal -from _utils import Opt, Opts, check_args from pytest import mark, raises from startle.error import ParserValueError +from ._utils import Opt, Opts, check_args + def check(draw: Callable, opt: Opt): check_args(draw, opt("shape", ["square"]), ["square"], {}) diff --git a/tests/test_parser/test_parse_optional.py b/tests/test_parser/test_parse_optional.py index c0bf11a..c9812e4 100644 --- a/tests/test_parser/test_parse_optional.py +++ b/tests/test_parser/test_parse_optional.py @@ -1,8 +1,9 @@ from typing import Callable, Optional, Union -from _utils import Opt, Opts, check_args from pytest import mark +from ._utils import Opt, Opts, check_args + def hi1(msg: str | None = None) -> None: print(f"{msg or 'hi'}!") diff --git a/tests/test_parser/test_parse_set.py b/tests/test_parser/test_parse_set.py index 37c7ef4..ead8bf8 100644 --- a/tests/test_parser/test_parse_set.py +++ b/tests/test_parser/test_parse_set.py @@ -1,10 +1,11 @@ from typing import Callable, Set -from _utils import check_args from pytest import mark, raises from startle.error import ParserOptionError, ParserValueError +from ._utils import check_args + def add_int(*, numbers: set[int]) -> None: print(sum(numbers)) diff --git a/tests/test_parser/test_parse_tuple.py b/tests/test_parser/test_parse_tuple.py index 5da7520..4c01902 100644 --- a/tests/test_parser/test_parse_tuple.py +++ b/tests/test_parser/test_parse_tuple.py @@ -1,10 +1,11 @@ from typing import Callable, Tuple -from _utils import check_args from pytest import mark, raises from startle.error import ParserOptionError, ParserValueError +from ._utils import check_args + def add_int(*, numbers: tuple[int]) -> None: print(sum(numbers)) diff --git a/tests/test_parser/test_parse_unknown.py b/tests/test_parser/test_parse_unknown.py index adc2126..dc7a111 100644 --- a/tests/test_parser/test_parse_unknown.py +++ b/tests/test_parser/test_parse_unknown.py @@ -1,10 +1,11 @@ from typing import Callable -from _utils import check_args from pytest import mark, raises from startle.error import ParserOptionError +from ._utils import check_args + def hi_w_args(msg: str, n: int, *args) -> None: pass diff --git a/tests/test_parser/test_parse_unsupported_type.py b/tests/test_parser/test_parse_unsupported_type.py index d355adf..edb5935 100644 --- a/tests/test_parser/test_parse_unsupported_type.py +++ b/tests/test_parser/test_parse_unsupported_type.py @@ -1,11 +1,12 @@ import re from typing import Callable -from _utils import check_args from pytest import mark, raises from startle.error import ParserConfigError +from ._utils import check_args + class Spell: pass diff --git a/tests/test_start/__init__.py b/tests/test_start/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_start/_utils.py b/tests/test_start/_utils.py new file mode 100644 index 0000000..7302f7f --- /dev/null +++ b/tests/test_start/_utils.py @@ -0,0 +1,35 @@ +import sys +from typing import Callable + +from pytest import raises + +from startle import start + + +def run_w_explicit_args(func: Callable[..., None], args: list[str]) -> None: + start(func, args) + + +def run_w_sys_argv(func: Callable[..., None], args: list[str]) -> None: + old_argv = sys.argv[1:] + sys.argv[1:] = args + start(func) + sys.argv[1:] = old_argv + + +def check( + capsys, run: Callable, f: Callable | list[Callable], args: list[str], expected: str +) -> None: + run(f, args) + captured = capsys.readouterr() + assert captured.out == expected + + +def check_exits( + capsys, run: Callable, f: Callable, args: list[str], expected: str +) -> None: + with raises(SystemExit) as excinfo: + run(f, args) + assert str(excinfo.value) == "1" + captured = capsys.readouterr() + assert captured.out.startswith(expected) diff --git a/tests/test_start/test_start_cmds.py b/tests/test_start/test_start_cmds.py new file mode 100644 index 0000000..aa63001 --- /dev/null +++ b/tests/test_start/test_start_cmds.py @@ -0,0 +1,77 @@ +from typing import Callable + +from pytest import mark + +from ._utils import check, check_exits, run_w_explicit_args, run_w_sys_argv + + +def add(a: int, b: int) -> None: + """ + Add two numbers. + + Args: + a: The first number. + b: The second number. + """ + print(f"{a} + {b} = {a + b}") + + +def sub(a: int, b: int) -> None: + """ + Subtract two numbers. + + Args: + a: The first number. + b: The second number + """ + print(f"{a} - {b} = {a - b}") + + +def mul(a: int, b: int) -> None: + """ + Multiply two numbers. + + Args: + a: The first number. + b: The second number. + """ + print(f"{a} * {b} = {a * b}") + + +def div(a: int, b: int) -> None: + """ + Divide two numbers. + + Args: + a: The dividend. + b: The divisor. + """ + print(f"{a} / {b} = {a / b}") + + +@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv]) +def test_calc(capsys, run: Callable) -> None: + check(capsys, run, [add, sub, mul, div], ["add", "2", "3"], "2 + 3 = 5\n") + check(capsys, run, [add, sub, mul, div], ["sub", "2", "3"], "2 - 3 = -1\n") + check(capsys, run, [add, sub, mul, div], ["mul", "2", "3"], "2 * 3 = 6\n") + check(capsys, run, [add, sub, mul, div], ["div", "6", "3"], "6 / 3 = 2.0\n") + + check_exits( + capsys, run, [add, sub, mul, div], ["2", "3"], "Error: Unknown command 2!\n" + ) + check_exits(capsys, run, [add, sub, mul, div], [], "Error: No command given!\n") + + check_exits( + capsys, + run, + [add, sub, mul, div], + ["add", "2", "3", "4"], + "Error: Unexpected positional argument: `4`!\n", + ) + check_exits( + capsys, + run, + [add, sub, mul, div], + ["sub", "2"], + "Error: Required option `b` is not provided!\n", + ) diff --git a/tests/test_start.py b/tests/test_start/test_start_func.py similarity index 77% rename from tests/test_start.py rename to tests/test_start/test_start_func.py index a77fe45..63ef497 100644 --- a/tests/test_start.py +++ b/tests/test_start/test_start_func.py @@ -1,36 +1,8 @@ -import sys from typing import Callable -from pytest import mark, raises +from pytest import mark -from startle import start - - -def run1(func: Callable[..., None], args: list[str]) -> None: - start(func, args) - - -def run2(func: Callable[..., None], args: list[str]) -> None: - old_argv = sys.argv[1:] - sys.argv[1:] = args - start(func) - sys.argv[1:] = old_argv - - -def check(capsys, run: Callable, f: Callable, args: list[str], expected: str) -> None: - run(f, args) - captured = capsys.readouterr() - assert captured.out == expected - - -def check_exits( - capsys, run: Callable, f: Callable, args: list[str], expected: str -) -> None: - with raises(SystemExit) as excinfo: - run(f, args) - assert str(excinfo.value) == "1" - captured = capsys.readouterr() - assert captured.out.startswith(expected) +from ._utils import check, check_exits, run_w_explicit_args, run_w_sys_argv def hi1(name: str, count: int = 1) -> None: @@ -64,7 +36,7 @@ def hi6(*, name: str, count: int = 1) -> None: @mark.parametrize("hi", [hi1, hi2, hi3, hi4, hi5, hi6]) -@mark.parametrize("run", [run1, run2]) +@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv]) def test_hi(capsys, run: Callable, hi: Callable) -> None: if hi in [hi1, hi2, hi3, hi4]: check(capsys, run, hi, ["Alice"], "Hello, Alice!\n") @@ -116,7 +88,7 @@ def test_hi(capsys, run: Callable, hi: Callable) -> None: @mark.parametrize("hi", [hi1, hi2, hi3, hi4, hi5, hi6]) -@mark.parametrize("run", [run1, run2]) +@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv]) def test_parse_err(capsys, run: Callable, hi: Callable) -> None: if hi in [hi1, hi5, hi6]: check_exits( @@ -146,7 +118,7 @@ def test_parse_err(capsys, run: Callable, hi: Callable) -> None: ) -@mark.parametrize("run", [run1, run2]) +@mark.parametrize("run", [run_w_explicit_args, run_w_sys_argv]) def test_config_err(capsys, run: Callable) -> None: def f(help: bool = False) -> None: pass