Skip to content

Commit

Permalink
Add Cmds class for multiple functions / commands (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
oir authored Nov 24, 2024
1 parent b4cc7b7 commit 2c33010
Show file tree
Hide file tree
Showing 23 changed files with 423 additions and 70 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,8 @@ venv/
.coverage
coverage.lcov

# Mypy
.mypy_cache/

# hatchling autogenerates version
/startle/_version.py
49 changes: 49 additions & 0 deletions examples/calc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from startle import start


def add(a: int, b: int) -> None:
"""
Add two numbers.
Args:
a: The first number.
b: The second number.
"""
print(f"{a} + {b} = {a + b}")


def sub(a: int, b: int) -> None:
"""
Subtract two numbers.
Args:
a: The first number.
b: The second number
"""
print(f"{a} - {b} = {a - b}")


def mul(a: int, b: int) -> None:
"""
Multiply two numbers.
Args:
a: The first number.
b: The second number.
"""
print(f"{a} * {b} = {a * b}")


def div(a: int, b: int) -> None:
"""
Divide two numbers.
Args:
a: The dividend.
b: The divisor.
"""
print(f"{a} / {b} = {a / b}")


if __name__ == "__main__":
start([add, sub, mul, div])
27 changes: 27 additions & 0 deletions examples/cat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Example invocations (for fish shell):
❯ python examples/cat.py examples/wc.py examples/cat.py --delim=\n===\n\n
❯ python examples/cat.py --delim=\n===\n\n examples/cat.py examples/wc.py
"""

from pathlib import Path

from startle import start


def cat(files: list[Path], /, *, delim: str = "") -> None:
"""
Concatenate files with an optional delimiter.
Args:
files: The files to concatenate.
delim: The delimiter to use.
"""

for i, file in enumerate(files):
if i:
print(delim, end="")
print(file.read_text(), end="")


start(cat)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ select = ["E4", "E7", "E9", "F", "I"]

[tool.pytest.ini_options]
addopts = "--cov=startle --cov-report=term-missing --cov-report=lcov"
python_files = ["tests/*.py", "tests/**/*.py"]

[tool.coverage.run]
omit = ["startle/_version.py"]
omit = ["startle/_version.py"]
68 changes: 68 additions & 0 deletions startle/_start.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
import sys
from typing import Callable, TypeVar

from rich.console import Console

from .args import Args
from .cmds import Cmds
from .error import ParserConfigError, ParserOptionError, ParserValueError
from .inspect import make_args

T = TypeVar("T")


def start(
obj: Callable | list[Callable], args: list[str] | None = None, caught: bool = True
):
if isinstance(obj, list):
return _start_cmds(obj, args, caught)
else:
return _start_func(obj, args, caught)


def _start_func(
func: Callable[..., T], args: list[str] | None = None, caught: bool = True
) -> T:
"""
Expand Down Expand Up @@ -52,3 +64,59 @@ def start(
raise SystemExit(1)
else:
raise e


def _start_cmds(
funcs: list[Callable], cli_args: list[str] | None = None, caught: bool = True
):
""" """

def cmd_name(func: Callable) -> str:
return func.__name__.replace("_", "-")

def prog_name(func: Callable) -> str:
return f"{sys.argv[0]} {cmd_name(func)}"

cmd2func: dict[str, Callable] = {cmd_name(func): func for func in funcs}

try:
# first, make Cmds object from the functions
cmds = Cmds(
{
cmd_name(func): make_args(func, program_name=prog_name(func))
for func in funcs
}
)
except ParserConfigError as e:
if caught:
console = Console()
console.print(f"[bold red]Error:[/bold red] [red]{e}[/red]\n")
raise SystemExit(1)
else:
raise e

try:
# then, parse the arguments from the CLI
args: Args | None = None
cmd, args = cmds.parse(cli_args)

# then turn the parsed arguments into function arguments
f_args, f_kwargs = args.make_func_args()

# finally, call the function with the arguments
func = cmd2func[cmd]
return func(*f_args, **f_kwargs)
except (ParserOptionError, ParserValueError) as e:
if caught:
console = Console()
console.print(f"[bold red]Error:[/bold red] [red]{e}[/red]\n")
if args: # error happened after parsing the command
args.print_help(console, usage_only=True)
else: # error happened before parsing the command
cmds.print_help(console, usage_only=True)
console.print(
"\n[dim]For more information, run with [green][b]-?[/b]|[b]--help[/b][/green].[/dim]"
)
raise SystemExit(1)
else:
raise e
47 changes: 23 additions & 24 deletions startle/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ class Args:
"""

brief: str = ""
program_name: str = ""

_positional_args: list[Arg] = field(default_factory=list)
_named_args: list[Arg] = field(default_factory=list)
_name2idx: dict[str, int] = field(default_factory=dict)
# note that _name2idx is many to one, because a name can be both short and long

_var_args: Arg | None = None # remaining unk args for functions with *args
_var_kwargs: Arg | None = None # remaining unk kwargs for functions with **kwargs
Expand Down Expand Up @@ -274,30 +276,27 @@ def var(opt: Arg | str) -> str:

return positional_args, named_args

def parse(self, args: list[str] | None = None) -> "Args":
def parse(self, cli_args: list[str] | None = None) -> "Args":
"""
Parse the command-line arguments.
Args:
args: The arguments to parse. If None, uses the arguments from the CLI.
cli_args: The arguments to parse. If None, uses the arguments from the CLI.
Returns:
Self, for chaining.
"""
if args is not None:
self._parse(args)
if cli_args is not None:
self._parse(cli_args)
else:
self._parse(sys.argv[1:])
return self

def print_help(
self, console=None, program_name: str | None = None, usage_only: bool = False
) -> None:
def print_help(self, console=None, usage_only: bool = False) -> None:
"""
Print the help message to the console.
Args:
console: A rich console to print to. If None, uses the default console.
program_name: The name of the program to use in the help message.
usage_only: Whether to print only the usage line.
"""
import sys
Expand All @@ -306,7 +305,7 @@ def print_help(
from rich.table import Table
from rich.text import Text

name = program_name or sys.argv[0]
name = self.program_name or sys.argv[0]

positional_only = [
arg
Expand Down Expand Up @@ -388,22 +387,30 @@ def help(arg: Arg) -> Text:
)
return helptext

# print brief if it exists
console = console or Console()
if self.brief and not usage_only:
console.print(self.brief + "\n")

# then print usage line
console.print(Text("Usage:", style=sty_title))
console.print(
Text(f" {name} ")
+ Text(" ").join([usage(arg, "usage line") for arg in positional_only])
+ Text(" ")
+ Text(" ").join(
[usage(opt, "usage line") for opt in positional_and_named + named_only]
)
usage_line = Text(f" {name}")
pos_only_str = Text(" ").join(
[usage(arg, "usage line") for arg in positional_only]
)
if pos_only_str:
usage_line += Text(" ") + pos_only_str
named_str = Text(" ").join(
[usage(opt, "usage line") for opt in positional_and_named + named_only]
)
if named_str:
usage_line += Text(" ") + named_str
console.print(usage_line)

if usage_only:
return

# then print help message for each argument
if positional_only + positional_and_named + named_only:
console.print(Text("\nwhere", style=sty_title))

Expand All @@ -427,11 +434,3 @@ def help(arg: Arg) -> Text:
)

console.print(table)

def __repr__(self) -> str:
rval = "<Args object>\n"
for arg in self._positional_args:
rval += f" <positional> {arg.metavar}: {arg._value}\n"
for arg in self._named_args:
rval += f" <named> {arg.name.long}: {arg._value}\n"
return rval
Loading

0 comments on commit 2c33010

Please sign in to comment.