Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Put the subcmd name in the validate_output error message [#1539] #1540

Merged
merged 2 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions augur/curate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def register_parser(parent_subparsers):
return parser


def validate_records(records: Iterable[dict], is_input: bool) -> Iterable[dict]:
def validate_records(records: Iterable[dict], subcmd_name: str, is_input: bool) -> Iterable[dict]:
"""
Validate that the provided *records* all have the same fields.
Uses the keys of the first record to check against all other records.
Expand All @@ -131,6 +131,10 @@ def validate_records(records: Iterable[dict], is_input: bool) -> Iterable[dict]:
----------
records: iterable of dict

subcmd_name: str
The name of the subcommand whose output is being validated; used in
error messages displayed to the user.

is_input: bool
Whether the provided records come directly from user provided input
"""
Expand All @@ -140,8 +144,8 @@ def validate_records(records: Iterable[dict], is_input: bool) -> Iterable[dict]:
else:
# Hopefully users should not run into this error as it means we are
# not uniformly adding/removing fields from records
error_message += dedent("""\
Something unexpected happened during the augur curate command.
error_message += dedent(f"""\
Something unexpected happened during the augur curate {subcmd_name} command.
To report this, please open a new issue including the original command:
<https://github.com/nextstrain/augur/issues/new/choose>
""")
Expand Down Expand Up @@ -213,14 +217,17 @@ def run(args):
input files can be provided via the command line options `--metadata` and `--fasta`.
See the command's help message for more details."""))

# Get the name of the subcmd being run
subcmd_name = args.subcommand

# Validate records have the same input fields
validated_input_records = validate_records(records, True)
validated_input_records = validate_records(records, subcmd_name, True)

# Run subcommand to get modified records
modified_records = getattr(args, SUBCOMMAND_ATTRIBUTE).run(args, validated_input_records)

# Validate modified records have the same output fields
validated_output_records = validate_records(modified_records, False)
validated_output_records = validate_records(modified_records, subcmd_name, False)

# Output modified records
# First output FASTA, since the write fasta function yields the records again
Expand Down
47 changes: 47 additions & 0 deletions tests/io/test_curate_validate_records.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
from augur.curate import validate_records
from augur.errors import AugurError


@pytest.fixture
def good_records():
return [
{"geo_loc_name": "Canada/Vancouver"},
{"geo_loc_name": "Canada/Vancouver"},
]


@pytest.fixture
def bad_records():
return [
{"geo_loc_name": "Canada/Vancouver"},
{"geo_loc_name2": "Canada/Vancouver"},
]


class TestCurateValidateRecords:
def test_validate_input(self, good_records):
validated_records = validate_records(good_records, "test_subcmd", True)
assert list(validated_records) == good_records, "good input records validate"

def test_validate_output(self, good_records):
validated_records = validate_records(good_records, "test_subcmd", False)

assert list(validated_records) == good_records, "good output records validate"

def test_validate_bad_records(self, bad_records):
with pytest.raises(AugurError) as e:
list(validate_records(bad_records, "test_subcmd", True))
assert str(e.value).startswith(
"Records do not have the same fields!"
), "bad input records throw exception with expected message"

def test_validate_bad_output(self, bad_records):
with pytest.raises(AugurError) as e:
list(validate_records(bad_records, "test_subcmd", False))
assert str(e.value).startswith(
"Records do not have the same fields!"
), "bad output records throw exception with expected message"
assert (
"test_subcmd" in str(e.value)
), "bad output records throw exception with subcmd name in the message"
Loading