Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify the help on incorrect positional arg #75

Merged
merged 5 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions src/turnkeyml/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ def error(self, message):
self.print_help()
sys.exit(2)

def print_cache_help(self):
print("Error: a cache command is required")
self.print_help()
sys.exit(2)

jeremyfowers marked this conversation as resolved.
Show resolved Hide resolved

def print_version(_):
"""
Expand Down Expand Up @@ -90,11 +85,11 @@ def main():
# Parser for the "benchmark" command
#######################################

def check_extension(choices, file_name):
def check_extension(choices, file_name, error_func):
_, extension = os.path.splitext(file_name.split("::")[0])
if extension[1:].lower() not in choices:
raise exceptions.ArgError(
f"input_files must end with .py, .onnx, or .txt (got '{file_name}')"
error_func(
f"input_files must end with .py, .onnx, or .txt (got '{file_name}')\n"
jeremyfowers marked this conversation as resolved.
Show resolved Hide resolved
)
return file_name

Expand All @@ -109,7 +104,9 @@ def check_extension(choices, file_name):
"input_files",
nargs="+",
help="One or more script (.py), ONNX (.onnx), or input list (.txt) files to be benchmarked",
type=lambda file: check_extension(("py", "onnx", "txt"), file),
type=lambda file: check_extension(
("py", "onnx", "txt"), file, benchmark_parser.error
jeremyfowers marked this conversation as resolved.
Show resolved Hide resolved
),
)

toolchain_select_group = benchmark_parser.add_argument_group(
Expand Down Expand Up @@ -500,26 +497,28 @@ def check_extension(choices, file_name):
# on a target script. If the user doesn't provide a command,
# we alter argv to insert the command for them.

# Special characters that indicate a string is a filename, not a command
file_chars = [".", "/", "\\", "*"]
jeremyfowers marked this conversation as resolved.
Show resolved Hide resolved

if len(sys.argv) > 1:
first_arg = sys.argv[1]
if first_arg not in subparsers.choices.keys() and "-h" not in first_arg:
if "." in first_arg:
if any(char_to_check in first_arg for char_to_check in file_chars):
# User has provided a file as the first positional arg
sys.argv.insert(1, "benchmark")
else:
# User has provided a command as the first positional arg
# Check how close we are from each of the valid options
# NOTE: if we are not close to a valid option, we will let
# argparse detect and raise the error
valid_options = list(subparsers.choices.keys())
close_matches = get_close_matches(first_arg, valid_options)

error_msg = f"Unexpected positional argument `turnkey {first_arg}`. "
if close_matches:
error_msg += f"Did you mean `turnkey {close_matches[0]}`?"
else:
error_msg += (
"The first positional argument must either be "
"an input file with the .py or .onnx file extension or "
f"one of the following commands: {valid_options}."
raise exceptions.ArgError(
f"Unexpected command `turnkey {first_arg}`. "
f"Did you mean `turnkey {close_matches[0]}`?"
)
raise exceptions.ArgError(error_msg)
jeremyfowers marked this conversation as resolved.
Show resolved Hide resolved

args = parser.parse_args()

Expand Down
2 changes: 1 addition & 1 deletion test/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def test_016_skip_compiled(self):

def test_017_invalid_file_type(self):
# Ensure that we get an error when running turnkey with invalid input_files
with self.assertRaises(exceptions.ArgError):
jeremyfowers marked this conversation as resolved.
Show resolved Hide resolved
with self.assertRaises(SystemExit):
testargs = ["turnkey", "gobbledegook"]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
Expand Down
Loading