Skip to content

Commit

Permalink
fix: apply some consistent coding style
Browse files Browse the repository at this point in the history
Since the generalized ingest PR triggered some discussion about the ingest/bin scripts, I applied the following consistent coding styles across the ingest scripts.

1) Separated argument parsing into its own parse_args function
2) Separated out the main function
3) Applied snakefmt to the Python files

Any of the above style decisions are open to discussion. An inconsistent coding style within the same repository is like looking at a wall of pictures where some of the pictures are crooked.
  • Loading branch information
j23414 committed Apr 1, 2023
1 parent f281e2a commit 430ff78
Show file tree
Hide file tree
Showing 15 changed files with 555 additions and 337 deletions.
119 changes: 75 additions & 44 deletions ingest/bin/apply-geolocation-rules
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,44 @@ from collections import defaultdict
from sys import exit, stderr, stdin, stdout


def parse_args():
    """Parse command-line arguments for applying geolocation rules to NDJSON records.

    Returns an argparse.Namespace with the field-name options (each with a
    sensible default) and the required path to the geolocation rules TSV.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--region-field",
        default="region",
        help="Field that contains regions in NDJSON records.",
    )
    parser.add_argument(
        "--country-field",
        default="country",
        help="Field that contains countries in NDJSON records.",
    )
    parser.add_argument(
        "--division-field",
        default="division",
        help="Field that contains divisions in NDJSON records.",
    )
    parser.add_argument(
        "--location-field",
        default="location",
        help="Field that contains location in NDJSON records.",
    )
    parser.add_argument(
        "--geolocation-rules",
        metavar="TSV",
        required=True,
        # Implicit string concatenation; note the trailing spaces so the
        # sentences do not run together in the rendered --help output.
        help="TSV file of geolocation rules with the format: "
        "'<raw_geolocation><tab><annotated_geolocation>' where the raw and annotated geolocations "
        "are formatted as '<region>/<country>/<division>/<location>'. "
        "If creating a general rule, then the raw field value can be substituted with '*'. "
        "Lines starting with '#' will be ignored as comments. "
        "Trailing '#' will be ignored as comments.",
    )
    return parser.parse_args()


class CyclicGeolocationRulesError(Exception):
    """Raised when the geolocation rules appear to contain a cycle, so that
    repeatedly applying them would never reach a fixed point.
    """
    pass

Expand All @@ -29,50 +67,50 @@ def load_geolocation_rules(geolocation_rules_file):
}
"""
geolocation_rules = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
with open(geolocation_rules_file, 'r') as rules_fh:
with open(geolocation_rules_file, "r") as rules_fh:
for line in rules_fh:
# ignore comments
if line.lstrip()[0] == '#':
if line.lstrip()[0] == "#":
continue

row = line.strip('\n').split('\t')
row = line.strip("\n").split("\t")
# Skip lines that cannot be split into raw and annotated geolocations
if len(row) != 2:
print(
f"WARNING: Could not decode geolocation rule {line!r}.",
"Please make sure rules are formatted as",
"'region/country/division/location<tab>region/country/division/location'.",
file=stderr)
file=stderr,
)
continue

# remove trailing comments
row[-1] = row[-1].partition('#')[0].rstrip()
raw , annot = tuple( row[0].split('/') ) , tuple( row[1].split('/') )
row[-1] = row[-1].partition("#")[0].rstrip()
raw, annot = tuple(row[0].split("/")), tuple(row[1].split("/"))

# Skip lines where raw or annotated geolocations cannot be split into 4 fields
if len(raw) != 4:
print(
f"WARNING: Could not decode the raw geolocation {row[0]!r}.",
"Please make sure it is formatted as 'region/country/division/location'.",
file=stderr
file=stderr,
)
continue

if len(annot) != 4:
print(
f"WARNING: Could not decode the annotated geolocation {row[1]!r}.",
"Please make sure it is formatted as 'region/country/division/location'.",
file=stderr
file=stderr,
)
continue


geolocation_rules[raw[0]][raw[1]][raw[2]][raw[3]] = annot

return geolocation_rules


def get_annotated_geolocation(geolocation_rules, raw_geolocation, rule_traversal = None):
def get_annotated_geolocation(geolocation_rules, raw_geolocation, rule_traversal=None):
"""
Gets the annotated geolocation for the *raw_geolocation* in the provided
*geolocation_rules*.
Expand Down Expand Up @@ -107,13 +145,15 @@ def get_annotated_geolocation(geolocation_rules, raw_geolocation, rule_traversal
if isinstance(current_rules, dict):
next_traversal_target = raw_geolocation[len(rule_traversal)]
rule_traversal.append(next_traversal_target)
return get_annotated_geolocation(geolocation_rules, raw_geolocation, rule_traversal)
return get_annotated_geolocation(
geolocation_rules, raw_geolocation, rule_traversal
)

# We did not find any matching rule for the last traversal target
if current_rules is None:
# If we've used all general rules and we still haven't found a match,
# then there are no applicable rules for this geolocation
if all(value == '*' for value in rule_traversal):
if all(value == "*" for value in rule_traversal):
return None

# If we failed to find matching rule with a general rule as the last
Expand All @@ -122,14 +162,14 @@ def get_annotated_geolocation(geolocation_rules, raw_geolocation, rule_traversal
# [A, *, B, *] => [A, *, B]
# [A, B, *, *] => [A, B]
# [A, *, *, *] => [A]
if rule_traversal[-1] == '*':
if rule_traversal[-1] == "*":
# Find the index of the first of the consecutive '*' from the
# end of the rule_traversal
# [A, *, B, *] => first_consecutive_general_rule_index = 3
# [A, B, *, *] => first_consecutive_general_rule_index = 2
# [A, *, *, *] => first_consecutive_general_rule_index = 1
for index, field_value in reversed(list(enumerate(rule_traversal))):
if field_value == '*':
if field_value == "*":
first_consecutive_general_rule_index = index
else:
break
Expand All @@ -138,9 +178,11 @@ def get_annotated_geolocation(geolocation_rules, raw_geolocation, rule_traversal

# Set the final value to '*' in hopes that by moving to a general rule,
# we can find a matching rule.
rule_traversal[-1] = '*'
rule_traversal[-1] = "*"

return get_annotated_geolocation(geolocation_rules, raw_geolocation, rule_traversal)
return get_annotated_geolocation(
geolocation_rules, raw_geolocation, rule_traversal
)


def transform_geolocations(geolocation_rules, geolocation):
Expand All @@ -161,7 +203,9 @@ def transform_geolocations(geolocation_rules, geolocation):
continue_to_apply = True

while continue_to_apply:
annotated_values = get_annotated_geolocation(geolocation_rules, transformed_values)
annotated_values = get_annotated_geolocation(
geolocation_rules, transformed_values
)

# Stop applying rules if no annotated values were found
if annotated_values is None:
Expand All @@ -178,7 +222,7 @@ def transform_geolocations(geolocation_rules, geolocation):
new_values = list(transformed_values)
for index, value in enumerate(annotated_values):
# Keep original value if annotated value is '*'
if value != '*':
if value != "*":
new_values[index] = value

# Stop applying rules if this rule did not change the values,
Expand All @@ -191,44 +235,31 @@ def transform_geolocations(geolocation_rules, geolocation):
return transformed_values


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--region-field", default="region",
help="Field that contains regions in NDJSON records.")
parser.add_argument("--country-field", default="country",
help="Field that contains countries in NDJSON records.")
parser.add_argument("--division-field", default="division",
help="Field that contains divisions in NDJSON records.")
parser.add_argument("--location-field", default="location",
help="Field that contains location in NDJSON records.")
parser.add_argument("--geolocation-rules", metavar="TSV", required=True,
help="TSV file of geolocation rules with the format: " +
"'<raw_geolocation><tab><annotated_geolocation>' where the raw and annotated geolocations " +
"are formatted as '<region>/<country>/<division>/<location>'. " +
"If creating a general rule, then the raw field value can be substituted with '*'." +
"Lines starting with '#' will be ignored as comments." +
"Trailing '#' will be ignored as comments.")

args = parser.parse_args()

location_fields = [args.region_field, args.country_field, args.division_field, args.location_field]
if __name__ == "__main__":
args = parse_args()

location_fields = [
args.region_field,
args.country_field,
args.division_field,
args.location_field,
]

geolocation_rules = load_geolocation_rules(args.geolocation_rules)

for record in stdin:
record = json.loads(record)

try:
annotated_values = transform_geolocations(geolocation_rules, [record[field] for field in location_fields])
annotated_values = transform_geolocations(
geolocation_rules, [record[field] for field in location_fields]
)
except CyclicGeolocationRulesError as e:
print(e, file=stderr)
exit(1)

for index, field in enumerate(location_fields):
record[field] = annotated_values[index]

json.dump(record, stdout, allow_nan=False, indent=None, separators=',:')
json.dump(record, stdout, allow_nan=False, indent=None, separators=",:")
print()
16 changes: 11 additions & 5 deletions ingest/bin/csv-to-ndjson
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,15 @@ import csv
import json
from sys import stdin, stdout

# 200 MiB; default is 128 KiB
csv.field_size_limit(200 * 1024 * 1024)

for row in csv.DictReader(stdin):
json.dump(row, stdout, allow_nan = False, indent = None, separators = ',:')
print()
def main():
    """Stream CSV rows from stdin to stdout as NDJSON records."""
    # Raise the CSV field size cap to 200 MiB (library default is 128 KiB)
    # so oversized fields are not rejected by the csv module.
    csv.field_size_limit(200 * 1024 * 1024)

    reader = csv.DictReader(stdin)
    for record in reader:
        # Compact separators and no indentation keep each record on one line.
        stdout.write(json.dumps(record, allow_nan=False, indent=None, separators=",:"))
        stdout.write("\n")


if __name__ == "__main__":
    main()
81 changes: 54 additions & 27 deletions ingest/bin/fasta-to-ndjson
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,36 @@ import sys
from augur.io import read_sequences


if __name__ == '__main__':
def parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--fasta", required=True,
help="FASTA file to be transformed into NDJSON format")
parser.add_argument("--fields", nargs="+",
help="Fields in the FASTA header, listed in the same order as the header. " +
"These will be used as the keys in the final NDJSON output. " +
"One of the fields must be 'strain'. " +
"These cannot include the field 'sequence' as this field is reserved for the genomic sequence.")
parser.add_argument("--separator", default='|',
help="Field separator in the FASTA header")
parser.add_argument("--exclude", nargs="*",
parser.add_argument(
"--fasta", required=True, help="FASTA file to be transformed into NDJSON format"
)
parser.add_argument(
"--fields",
nargs="+",
help="Fields in the FASTA header, listed in the same order as the header. "
+ "These will be used as the keys in the final NDJSON output. "
+ "One of the fields must be 'strain'. "
+ "These cannot include the field 'sequence' as this field is reserved for the genomic sequence.",
)
parser.add_argument(
"--separator", default="|", help="Field separator in the FASTA header"
)
parser.add_argument(
"--exclude",
nargs="*",
help="List of fields to exclude from final NDJSON record. "
"These cannot include 'strain' or 'sequence'.")
"These cannot include 'strain' or 'sequence'.",
)

args = parser.parse_args()
return parser.parse_args()


def main():
args = parse_args()

fasta_fields = [field.lower() for field in args.fields]

Expand All @@ -44,43 +55,59 @@ if __name__ == '__main__':

passed_checks = True

if 'strain' not in fasta_fields:
if "strain" not in fasta_fields:
print("ERROR: FASTA fields must include a 'strain' field.", file=sys.stderr)
passed_checks = False

if 'sequence' in fasta_fields:
if "sequence" in fasta_fields:
print("ERROR: FASTA fields cannot include a 'sequence' field.", file=sys.stderr)
passed_checks = False

if 'strain' in exclude_fields:
print("ERROR: The field 'strain' cannot be excluded from the output.", file=sys.stderr)
if "strain" in exclude_fields:
print(
"ERROR: The field 'strain' cannot be excluded from the output.",
file=sys.stderr,
)
passed_checks = False

if 'sequence' in exclude_fields:
print("ERROR: The field 'sequence' cannot be excluded from the output.", file=sys.stderr)
if "sequence" in exclude_fields:
print(
"ERROR: The field 'sequence' cannot be excluded from the output.",
file=sys.stderr,
)
passed_checks = False

missing_fields = [field for field in exclude_fields if field not in fasta_fields]
if missing_fields:
print(f"ERROR: The following exclude fields do not match any FASTA fields: {missing_fields}", file=sys.stderr)
print(
f"ERROR: The following exclude fields do not match any FASTA fields: {missing_fields}",
file=sys.stderr,
)
passed_checks = False

if not passed_checks:
print("ERROR: Failed to parse FASTA file into NDJSON records.","See detailed errors above.", file=sys.stderr)
print(
"ERROR: Failed to parse FASTA file into NDJSON records.",
"See detailed errors above.",
file=sys.stderr,
)
sys.exit(1)

sequences = read_sequences(args.fasta)

for sequence in sequences:
field_values = [
value.strip()
for value in sequence.description.split(args.separator)
value.strip() for value in sequence.description.split(args.separator)
]
record = dict(zip(fasta_fields, field_values))
record['sequence'] = str(sequence.seq).upper()
record["sequence"] = str(sequence.seq).upper()

for field in exclude_fields:
del record[field]

json.dump(record, sys.stdout, allow_nan=False, indent=None, separators=',:')
json.dump(record, sys.stdout, allow_nan=False, indent=None, separators=",:")
print()


if __name__ == "__main__":
main()
Loading

0 comments on commit 430ff78

Please sign in to comment.