From c0b443139ad32f281509e916d6aa0f44bcea7e92 Mon Sep 17 00:00:00 2001 From: Jeremy Date: Fri, 15 Dec 2023 17:18:45 -0500 Subject: [PATCH 1/4] Simplify the help on incorrect positional arg Signed-off-by: Jeremy --- src/turnkeyml/cli/cli.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/turnkeyml/cli/cli.py b/src/turnkeyml/cli/cli.py index ef05573..60cb674 100644 --- a/src/turnkeyml/cli/cli.py +++ b/src/turnkeyml/cli/cli.py @@ -476,22 +476,21 @@ def check_extension(choices, file_name): first_arg = sys.argv[1] if first_arg not in subparsers.choices.keys() and "-h" not in first_arg: if "." in first_arg: + # User has provided a file as the first positional arg sys.argv.insert(1, "benchmark") else: + # User has provided a command as the first positional arg # Check how close we are from each of the valid options + # NOTE: if we are not close to a valid option, we will let + # argparse handle the error valid_options = list(subparsers.choices.keys()) close_matches = get_close_matches(first_arg, valid_options) - error_msg = f"Unexpected positional argument `turnkey {first_arg}`. " if close_matches: - error_msg += f"Did you mean `turnkey {close_matches[0]}`?" - else: - error_msg += ( - "The first positional argument must either be " - "an input file with the .py or .onnx file extension or " - f"one of the following commands: {valid_options}." + raise exceptions.ArgError( + f"Unexpected positional argument `turnkey {first_arg}`. " + f"Did you mean `turnkey {close_matches[0]}`?" ) - raise exceptions.ArgError(error_msg) args = parser.parse_args() From 37b81846fad563c3b6f3eea7dcb180e81524d5fe Mon Sep 17 00:00:00 2001 From: Jeremy Date: Fri, 15 Dec 2023 19:26:11 -0500 Subject: [PATCH 2/4] Broader definition of 'file' Signed-off-by: Jeremy --- src/turnkeyml/cli/cli.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/turnkeyml/cli/cli.py b/src/turnkeyml/cli/cli.py index 60cb674..2829396 100644 --- a/src/turnkeyml/cli/cli.py +++ b/src/turnkeyml/cli/cli.py @@ -472,17 +472,20 @@ def check_extension(choices, file_name): # on a target script. If the user doesn't provide a command, # we alter argv to insert the command for them. + # Special characters that indicate a string is a filename, not a command + file_chars = [".", "/", "\\", "*"] + if len(sys.argv) > 1: first_arg = sys.argv[1] if first_arg not in subparsers.choices.keys() and "-h" not in first_arg: - if "." in first_arg: + if any(char_to_check in first_arg for char_to_check in file_chars): # User has provided a file as the first positional arg sys.argv.insert(1, "benchmark") else: # User has provided a command as the first positional arg # Check how close we are from each of the valid options # NOTE: if we are not close to a valid option, we will let - # argparse handle the error + # argparse detect and raise the error valid_options = list(subparsers.choices.keys()) close_matches = get_close_matches(first_arg, valid_options) From 8e7eb59d4cae839d6f7f0471bd1a66d0cc147f60 Mon Sep 17 00:00:00 2001 From: Jeremy Date: Fri, 15 Dec 2023 19:45:04 -0500 Subject: [PATCH 3/4] polishing Signed-off-by: Jeremy --- src/turnkeyml/cli/cli.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/turnkeyml/cli/cli.py b/src/turnkeyml/cli/cli.py index 2829396..411b533 100644 --- a/src/turnkeyml/cli/cli.py +++ b/src/turnkeyml/cli/cli.py @@ -21,11 +21,6 @@ def error(self, message): self.print_help() sys.exit(2) - def print_cache_help(self): - print("Error: a cache command is required") - self.print_help() - sys.exit(2) - def print_version(_): """ @@ -90,11 +85,11 @@ def main(): # Parser for the "benchmark" command ####################################### - def check_extension(choices, file_name): + def check_extension(choices, file_name, error_func): _, extension = os.path.splitext(file_name.split("::")[0]) if extension[1:].lower() not in choices: - raise exceptions.ArgError( - f"input_files must end with .py, .onnx, or .txt (got '{file_name}')" + error_func( + f"input_files must end with .py, .onnx, or .txt (got '{file_name}')\n" ) return file_name @@ -107,7 +102,9 @@ def check_extension(choices, file_name): "input_files", nargs="+", help="One or more script (.py), ONNX (.onnx), or input list (.txt) files to be benchmarked", - type=lambda file: check_extension(("py", "onnx", "txt"), file), + type=lambda file: check_extension( + ("py", "onnx", "txt"), file, benchmark_parser.error + ), ) slurm_or_processes_group = benchmark_parser.add_mutually_exclusive_group() @@ -491,7 +488,7 @@ def check_extension(choices, file_name): if close_matches: raise exceptions.ArgError( - f"Unexpected positional argument `turnkey {first_arg}`. " + f"Unexpected command `turnkey {first_arg}`. " f"Did you mean `turnkey {close_matches[0]}`?" ) From ff8cbd0bbad2dabb1d0f7af70d93225633bc3176 Mon Sep 17 00:00:00 2001 From: Jeremy Date: Mon, 18 Dec 2023 12:15:55 -0500 Subject: [PATCH 4/4] Update test Signed-off-by: Jeremy --- test/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cli.py b/test/cli.py index a1bba92..900934c 100644 --- a/test/cli.py +++ b/test/cli.py @@ -722,7 +722,7 @@ def test_016_skip_compiled(self): def test_017_invalid_file_type(self): # Ensure that we get an error when running turnkey with invalid input_files - with self.assertRaises(exceptions.ArgError): + with self.assertRaises(SystemExit): testargs = ["turnkey", "gobbledegook"] with patch.object(sys, "argv", flatten(testargs)): turnkeycli()