Skip to content

Commit

Permalink
Simplify the help on incorrect positional arg (#75)
Browse files Browse the repository at this point in the history
* Simplify the help on incorrect positional arg

Signed-off-by: Jeremy <jeremy.fowers@amd.com>

---------

Signed-off-by: Jeremy <jeremy.fowers@amd.com>
  • Loading branch information
jeremyfowers authored Dec 18, 2023
1 parent 68a3491 commit 21eacec
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 19 deletions.
35 changes: 17 additions & 18 deletions src/turnkeyml/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ def error(self, message):
self.print_help()
sys.exit(2)

def print_cache_help(self):
print("Error: a cache command is required")
self.print_help()
sys.exit(2)


def print_version(_):
"""
Expand Down Expand Up @@ -90,11 +85,11 @@ def main():
# Parser for the "benchmark" command
#######################################

def check_extension(choices, file_name):
def check_extension(choices, file_name, error_func):
_, extension = os.path.splitext(file_name.split("::")[0])
if extension[1:].lower() not in choices:
raise exceptions.ArgError(
f"input_files must end with .py, .onnx, or .txt (got '{file_name}')"
error_func(
f"input_files must end with .py, .onnx, or .txt (got '{file_name}')\n"
)
return file_name

Expand All @@ -109,7 +104,9 @@ def check_extension(choices, file_name):
"input_files",
nargs="+",
help="One or more script (.py), ONNX (.onnx), or input list (.txt) files to be benchmarked",
type=lambda file: check_extension(("py", "onnx", "txt"), file),
type=lambda file: check_extension(
("py", "onnx", "txt"), file, benchmark_parser.error
),
)

toolchain_select_group = benchmark_parser.add_argument_group(
Expand Down Expand Up @@ -500,26 +497,28 @@ def check_extension(choices, file_name):
# on a target script. If the user doesn't provide a command,
# we alter argv to insert the command for them.

# Special characters that indicate a string is a filename, not a command
file_chars = [".", "/", "\\", "*"]

if len(sys.argv) > 1:
first_arg = sys.argv[1]
if first_arg not in subparsers.choices.keys() and "-h" not in first_arg:
if "." in first_arg:
if any(char_to_check in first_arg for char_to_check in file_chars):
# User has provided a file as the first positional arg
sys.argv.insert(1, "benchmark")
else:
# User has provided a command as the first positional arg
# Check how close we are from each of the valid options
# NOTE: if we are not close to a valid option, we will let
# argparse detect and raise the error
valid_options = list(subparsers.choices.keys())
close_matches = get_close_matches(first_arg, valid_options)

error_msg = f"Unexpected positional argument `turnkey {first_arg}`. "
if close_matches:
error_msg += f"Did you mean `turnkey {close_matches[0]}`?"
else:
error_msg += (
"The first positional argument must either be "
"an input file with the .py or .onnx file extension or "
f"one of the following commands: {valid_options}."
raise exceptions.ArgError(
f"Unexpected command `turnkey {first_arg}`. "
f"Did you mean `turnkey {close_matches[0]}`?"
)
raise exceptions.ArgError(error_msg)

args = parser.parse_args()

Expand Down
2 changes: 1 addition & 1 deletion test/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def test_016_skip_compiled(self):

def test_017_invalid_file_type(self):
# Ensure that we get an error when running turnkey with invalid input_files
with self.assertRaises(exceptions.ArgError):
with self.assertRaises(SystemExit):
testargs = ["turnkey", "gobbledegook"]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
Expand Down

0 comments on commit 21eacec

Please sign in to comment.