diff --git a/pyproject.toml b/pyproject.toml index afd27eb..65816c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "tomli; python_version<'3.11'", "typing_extensions; python_version<'3.11'", "wsproto >=0.14.0", + "rich-click >=1.8.3,<2.0.0", ] [project.optional-dependencies] diff --git a/src/anycorn/__main__.py b/src/anycorn/__main__.py index 10d02a7..fc9faf5 100644 --- a/src/anycorn/__main__.py +++ b/src/anycorn/__main__.py @@ -1,15 +1,13 @@ from __future__ import annotations -import argparse import ssl import sys -import warnings + +import rich_click as click from .config import Config from .run import run -sentinel = object() - def _load_config(config_path: str | None) -> Config: if config_path is None: @@ -22,283 +20,309 @@ def _load_config(config_path: str | None) -> Config: return Config.from_toml(config_path) -def main(sys_args: list[str] | None = None) -> int: - parser = argparse.ArgumentParser() - parser.add_argument( - "application", help="The application to dispatch to as path.to.module:instance.path" - ) - parser.add_argument("--access-log", help="Deprecated, see access-logfile", default=sentinel) - parser.add_argument( - "--access-logfile", - help="The target location for the access log, use `-` for stdout", - default=sentinel, - ) - parser.add_argument( - "--access-logformat", - help="The log format for the access log, see help docs", - default=sentinel, - ) - parser.add_argument( - "--backlog", help="The maximum number of pending connections", type=int, default=sentinel - ) - parser.add_argument( - "-b", - "--bind", - dest="binds", - help=""" The TCP host/address to bind to. Should be either host:port, host, - unix:path or fd://num, e.g. 127.0.0.1:5000, 127.0.0.1, - unix:/tmp/socket or fd://33 respectively. """, - default=[], - action="append", - ) - parser.add_argument("--ca-certs", help="Path to the SSL CA certificate file", default=sentinel) - parser.add_argument("--certfile", help="Path to the SSL certificate file", default=sentinel) - parser.add_argument("--cert-reqs", help="See verify mode argument", type=int, default=sentinel) - parser.add_argument("--ciphers", help="Ciphers to use for the SSL setup", default=sentinel) - parser.add_argument( - "-c", - "--config", - help="Location of a TOML config file, or when prefixed with `file:` a Python file, or when prefixed with `python:` a Python module.", # noqa: E501 - default=None, - ) - parser.add_argument( - "--debug", - help="Enable debug mode, i.e. extra logging and checks", - action="store_true", - default=sentinel, - ) - parser.add_argument("--error-log", help="Deprecated, see error-logfile", default=sentinel) - parser.add_argument( - "--error-logfile", - "--log-file", - dest="error_logfile", - help="The target location for the error log, use `-` for stderr", - default=sentinel, - ) - parser.add_argument( - "--graceful-timeout", - help="""Time to wait after SIGTERM or Ctrl-C for any remaining requests (tasks) - to complete.""", - default=sentinel, - type=int, - ) - parser.add_argument( - "--read-timeout", - help="""Seconds to wait before timing out reads on TCP sockets""", - default=sentinel, - type=int, - ) - parser.add_argument( - "--max-requests", - help="""Maximum number of requests a worker will process before restarting""", - default=sentinel, - type=int, - ) - parser.add_argument( - "--max-requests-jitter", - help="This jitter causes the max-requests per worker to be " - "randomized by randint(0, max_requests_jitter)", - default=sentinel, - type=int, - ) - parser.add_argument( - "-g", "--group", help="Group to own any unix sockets.", default=sentinel, type=int - ) - parser.add_argument( - "-k", - "--worker-class", - dest="worker_class", - help="The type of worker to use. Options include asyncio and trio.", - default=sentinel, - ) - parser.add_argument( - "--keep-alive", - help="Seconds to keep inactive connections alive for", - default=sentinel, - type=int, - ) - parser.add_argument("--keyfile", help="Path to the SSL key file", default=sentinel) - parser.add_argument( - "--keyfile-password", help="Password to decrypt the SSL key file", default=sentinel - ) - parser.add_argument( - "--insecure-bind", - dest="insecure_binds", - help="""The TCP host/address to bind to. SSL options will not apply to these binds. - See *bind* for formatting options. Care must be taken! See HTTP -> HTTPS redirection docs. - """, - default=[], - action="append", - ) - parser.add_argument( - "--log-config", - help=""""A Python logging configuration file. This can be prefixed with - 'json:' or 'toml:' to load the configuration from a file in - that format. Default is the logging ini format.""", - default=sentinel, - ) - parser.add_argument( - "--log-level", help="The (error) log level, defaults to info", default=sentinel - ) - parser.add_argument( - "-p", "--pid", help="Location to write the PID (Program ID) to.", default=sentinel - ) - # FIXME - # parser.add_argument( - # "--quic-bind", - # dest="quic_binds", - # help="""The UDP/QUIC host/address to bind to. See *bind* for formatting - # options. - # """, - # default=[], - # action="append", - # ) - parser.add_argument( - "--reload", - help="Enable automatic reloads on code changes", - action="store_true", - default=sentinel, - ) - parser.add_argument( - "--root-path", help="The setting for the ASGI root_path variable", default=sentinel - ) - parser.add_argument( - "--server-name", - dest="server_names", - help="""The hostnames that can be served, requests to different hosts - will be responded to with 404s. - """, - default=[], - action="append", - ) - parser.add_argument( - "--statsd-host", help="The host:port of the statsd server", default=sentinel - ) - parser.add_argument("--statsd-prefix", help="Prefix for all statsd messages", default="") - parser.add_argument( - "-m", - "--umask", - help="The permissions bit mask to use on any unix sockets.", - default=sentinel, - type=int, - ) - parser.add_argument( - "-u", "--user", help="User to own any unix sockets.", default=sentinel, type=int - ) - - def _convert_verify_mode(value: str) -> ssl.VerifyMode: - try: - return ssl.VerifyMode[value] - except KeyError: - raise argparse.ArgumentTypeError(f"'{value}' is not a valid verify mode") - - parser.add_argument( - "--verify-mode", - help="SSL verify mode for peer's certificate, see ssl.VerifyMode enum for possible values.", - type=_convert_verify_mode, - default=sentinel, - ) - parser.add_argument( - "--websocket-ping-interval", - help="""If set this is the time in seconds between pings sent to the client. - This can be used to keep the websocket connection alive.""", - default=sentinel, - type=int, - ) - parser.add_argument( - "-w", - "--workers", - dest="workers", - help="The number of workers to spawn and use", - default=sentinel, - type=int, - ) - args = parser.parse_args(sys_args or sys.argv[1:]) - config = _load_config(args.config) - config.application_path = args.application +@click.command( + help="Start the server and dispatch to the APPLICATION as path.to.module:instance.path." +) +@click.argument( + "application", +) +@click.option( + "--access-logfile", + help="The target location for the access log, use `-` for stdout", +) +@click.option( + "--access-logformat", + help="The log format for the access log, see help docs", +) +@click.option("--backlog", type=int, help="The maximum number of pending connections") +@click.option( + "-b", + "--bind", + "binds", + help="The TCP host/address to bind to. Should be either host:port, host, " + "unix:path or fd://num, e.g. 127.0.0.1:5000, 127.0.0.1, " + "unix:/tmp/socket or fd://33 respectively.", + default=[], + multiple=True, +) +@click.option( + "--ca-certs", + help="Path to the SSL CA certificate file", +) +@click.option( + "--certfile", + help="Path to the SSL certificate file", +) +@click.option( + "--cert-reqs", + type=int, + help="See verify mode argument", +) +@click.option( + "--ciphers", + help="Ciphers to use for the SSL setup", +) +@click.option( + "-c", + "--config", + help="Location of a TOML config file, or when prefixed with `file:` a Python file, " + "or when prefixed with `python:` a Python module.", +) +@click.option( + "--debug", + help="Enable debug mode, i.e. extra logging and checks", + is_flag=True, +) +@click.option( + "--error-logfile", + "--log-file", + "error_logfile", + help="The target location for the error log, use `-` for stderr", +) +@click.option( + "--graceful-timeout", + help="Time to wait after SIGTERM or Ctrl-C for any remaining requests (tasks) to complete.", + type=int, +) +@click.option( + "--read-timeout", + help="Seconds to wait before timing out reads on TCP sockets", + type=int, +) +@click.option( + "--max-requests", + help="Maximum number of requests a worker will process before restarting", + type=int, +) +@click.option( + "--max-requests-jitter", + help="This jitter causes the max-requests per worker to be " + "randomized by randint(0, max_requests_jitter)", + type=int, +) +@click.option( + "-g", + "--group", + help="Group to own any unix sockets.", + type=int, +) +@click.option( + "-k", + "--worker-class", + help="The type of worker to use. Options include asyncio and trio.", + type=click.Choice(("asyncio", "trio")), +) +@click.option( + "--keep-alive", + help="Seconds to keep inactive connections alive for", + type=int, +) +@click.option( + "--keyfile", + help="Path to the SSL key file", +) +@click.option( + "--keyfile-password", + help="Password to decrypt the SSL key file", +) +@click.option( + "--insecure-bind", + "insecure_binds", + help="The TCP host/address to bind to. SSL options will not apply to these binds. " + "See *bind* for formatting options. Care must be taken! See HTTP -> HTTPS redirection docs.", + default=[], + multiple=True, +) +@click.option( + "--log-config", + help="A Python logging configuration file. This can be prefixed with " + "'json:' or 'toml:' to load the configuration from a file in " + " that format. Default is the logging ini format.", +) +@click.option( + "--log-level", + help="The (error) log level, defaults to info", +) +@click.option( + "-p", + "--pid", + help="Location to write the PID (Program ID) to.", +) +# FIXME +# parser.add_argument( +# "--quic-bind", +# dest="quic_binds", +# help="""The UDP/QUIC host/address to bind to. See *bind* for formatting +# options. +# """, +# default=[], +# action="append", +# ) +@click.option( + "--reload", + help="Enable automatic reloads on code changes", + is_flag=True, +) +@click.option( + "--root-path", + help="The setting for the ASGI root_path variable", +) +@click.option( + "--server-name", + "server_names", + help="The hostnames that can be served, requests to different hosts " + "will be responded to with 404s.", + default=[], + multiple=True, +) +@click.option( + "--statsd-host", + help="The host:port of the statsd server", +) +@click.option( + "--statsd-prefix", + help="Prefix for all statsd messages", + default="", +) +@click.option( + "-m", + "--umask", + help="The permissions bit mask to use on any unix sockets.", + type=int, +) +@click.option( + "-u", + "--user", + help="User to own any unix sockets.", + type=int, +) +@click.option( + "--verify-mode", + help="SSL verify mode for peer's certificate, see ssl.VerifyMode enum for possible values.", + type=click.Choice(("CERT_NONE", "CERT_OPTIONAL", "CERT_REQUIRED")), +) +@click.option( + "--websocket-ping-interval", + help="If set this is the time in seconds between pings sent to the client. " + "This can be used to keep the websocket connection alive.", + type=int, +) +@click.option( + "-w", + "--workers", + help="The number of workers to spawn and use", + type=int, +) +def main( + application: str, + access_logfile: str | None, + access_logformat: str | None, + backlog: int | None, + binds: list[str], + ca_certs: str | None, + certfile: str | None, + cert_reqs: int | None, + ciphers: str | None, + config: str | None, + debug: bool, + error_logfile: str | None, + graceful_timeout: int | None, + read_timeout: int | None, + max_requests: int | None, + max_requests_jitter: int | None, + group: int | None, + worker_class: str | None, + keep_alive: int | None, + keyfile: str | None, + keyfile_password: str | None, + insecure_binds: list[str], + log_config: str | None, + log_level: str | None, + pid: str | None, + reload: bool, + root_path: str | None, + server_names: list[str], + statsd_host: str | None, + statsd_prefix: str, + umask: int | None, + user: int | None, + verify_mode: str | None, + websocket_ping_interval: int | None, + workers: int | None, +) -> int: + config = _load_config(config) + config.application_path = application - if args.log_level is not sentinel: - config.loglevel = args.log_level - if args.access_logformat is not sentinel: - config.access_log_format = args.access_logformat - if args.access_log is not sentinel: - warnings.warn( - "The --access-log argument is deprecated, use `--access-logfile` instead", - DeprecationWarning, - ) - config.accesslog = args.access_log - if args.access_logfile is not sentinel: - config.accesslog = args.access_logfile - if args.backlog is not sentinel: - config.backlog = args.backlog - if args.ca_certs is not sentinel: - config.ca_certs = args.ca_certs - if args.certfile is not sentinel: - config.certfile = args.certfile - if args.cert_reqs is not sentinel: - config.cert_reqs = args.cert_reqs - if args.ciphers is not sentinel: - config.ciphers = args.ciphers - if args.debug is not sentinel: - config.debug = args.debug - if args.error_log is not sentinel: - warnings.warn( - "The --error-log argument is deprecated, use `--error-logfile` instead", - DeprecationWarning, - ) - config.errorlog = args.error_log - if args.error_logfile is not sentinel: - config.errorlog = args.error_logfile - if args.graceful_timeout is not sentinel: - config.graceful_timeout = args.graceful_timeout - if args.read_timeout is not sentinel: - config.read_timeout = args.read_timeout - if args.group is not sentinel: - config.group = args.group - if args.keep_alive is not sentinel: - config.keep_alive_timeout = args.keep_alive - if args.keyfile is not sentinel: - config.keyfile = args.keyfile - if args.keyfile_password is not sentinel: - config.keyfile_password = args.keyfile_password - if args.log_config is not sentinel: - config.logconfig = args.log_config - if args.max_requests is not sentinel: - config.max_requests = args.max_requests - if args.max_requests_jitter is not sentinel: - config.max_requests_jitter = args.max_requests - if args.pid is not sentinel: - config.pid_path = args.pid - if args.root_path is not sentinel: - config.root_path = args.root_path - if args.reload is not sentinel: - config.use_reloader = args.reload - if args.statsd_host is not sentinel: - config.statsd_host = args.statsd_host - if args.statsd_prefix is not sentinel: - config.statsd_prefix = args.statsd_prefix - if args.umask is not sentinel: - config.umask = args.umask - if args.user is not sentinel: - config.user = args.user - if args.worker_class is not sentinel: - config.worker_class = args.worker_class - if args.verify_mode is not sentinel: - config.verify_mode = args.verify_mode - if args.websocket_ping_interval is not sentinel: - config.websocket_ping_interval = args.websocket_ping_interval - if args.workers is not sentinel: - config.workers = args.workers + if log_level is not None: + config.loglevel = log_level + if access_logformat is not None: + config.access_log_format = access_logformat + if access_logfile is not None: + config.accesslog = access_logfile + if backlog is not None: + config.backlog = backlog + if ca_certs is not None: + config.ca_certs = ca_certs + if certfile is not None: + config.certfile = certfile + if cert_reqs is not None: + config.cert_reqs = cert_reqs + if ciphers is not None: + config.ciphers = ciphers + if debug is not None: + config.debug = debug + if error_logfile is not None: + config.errorlog = error_logfile + if graceful_timeout is not None: + config.graceful_timeout = graceful_timeout + if read_timeout is not None: + config.read_timeout = read_timeout + if group is not None: + config.group = group + if keep_alive is not None: + config.keep_alive_timeout = keep_alive + if keyfile is not None: + config.keyfile = keyfile + if keyfile_password is not None: + config.keyfile_password = keyfile_password + if log_config is not None: + config.logconfig = log_config + if max_requests is not None: + config.max_requests = max_requests + if max_requests_jitter is not None: + config.max_requests_jitter = max_requests + if pid is not None: + config.pid_path = pid + if root_path is not None: + config.root_path = root_path + if reload is not None: + config.use_reloader = reload + if statsd_host is not None: + config.statsd_host = statsd_host + if statsd_prefix is not None: + config.statsd_prefix = statsd_prefix + if umask is not None: + config.umask = umask + if user is not None: + config.user = user + if worker_class is not None: + config.worker_class = worker_class + if verify_mode is not None: + config.verify_mode = ssl.VerifyMode[verify_mode] + if websocket_ping_interval is not None: + config.websocket_ping_interval = websocket_ping_interval + if workers is not None: + config.workers = workers - if len(args.binds) > 0: - config.bind = args.binds - if len(args.insecure_binds) > 0: - config.insecure_bind = args.insecure_binds + if len(binds) > 0: + config.bind = binds + if len(insecure_binds) > 0: + config.insecure_bind = insecure_binds # FIXME # if len(args.quic_binds) > 0: # config.quic_bind = args.quic_binds - if len(args.server_names) > 0: - config.server_names = args.server_names + if len(server_names) > 0: + config.server_names = server_names return run(config) diff --git a/tests/test___main__.py b/tests/test___main__.py index ea82d13..f63c0d2 100644 --- a/tests/test___main__.py +++ b/tests/test___main__.py @@ -8,6 +8,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from anycorn.config import Config +from click.testing import CliRunner def test_load_config_none() -> None: @@ -53,12 +54,15 @@ def test_load_config(monkeypatch: MonkeyPatch) -> None: def test_main_cli_override( flag: str, set_value: str, config_key: str, monkeypatch: MonkeyPatch ) -> None: + runner = CliRunner() run_multiple = Mock() monkeypatch.setattr(anycorn.__main__, "run", run_multiple) path = os.path.join(os.path.dirname(__file__), "assets/config_ssl.py") raw_config = Config.from_pyfile(path) - anycorn.__main__.main(["--config", f"file:{path}", flag, str(set_value), "asgi:App"]) + runner.invoke( + anycorn.__main__.main, ["--config", f"file:{path}", flag, str(set_value), "asgi:App"] + ) run_multiple.assert_called() config = run_multiple.call_args_list[0][0][0] @@ -73,11 +77,12 @@ def test_main_cli_override( def test_verify_mode_conversion(monkeypatch: MonkeyPatch) -> None: + runner = CliRunner() run_multiple = Mock() monkeypatch.setattr(anycorn.__main__, "run", run_multiple) - with pytest.raises(SystemExit): - anycorn.__main__.main(["--verify-mode", "CERT_UNKNOWN", "asgi:App"]) + result = runner.invoke(anycorn.__main__.main, ["--verify-mode", "CERT_UNKNOWN", "asgi:App"]) + assert isinstance(result.exception, SystemExit) - anycorn.__main__.main(["--verify-mode", "CERT_REQUIRED", "asgi:App"]) + runner.invoke(anycorn.__main__.main, ["--verify-mode", "CERT_REQUIRED", "asgi:App"]) run_multiple.assert_called()