Skip to content

Commit

Permalink
WIP - allow resolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshadfield committed Oct 2, 2024
1 parent 5034ea4 commit cbcc232
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 10 deletions.
1 change: 1 addition & 0 deletions ingest/defaults/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,4 @@ grouping:
- authors
- abbr_authors
- institution
resolutions: defaults/segment_resolutions.yaml
27 changes: 27 additions & 0 deletions ingest/defaults/segment_resolutions.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
PICK_SEGMENT:
- strain: TRVL9760
accession: KP026181 # matches the metadata for the other segments for this strain
segment: S
PICK_FIELD:
- strain: H498913
field: date
accession: HQ830423 # 2 accessions have the same date of 1988, one has 1990. So we choose majority rule (as an example).
# following authors are very similar, pick the most complete looking
- strain: LET-2083
field: authors
accession: PP477309
- strain: LET-2102
field: authors
accession: PP477310
- strain: LET-2099
field: authors
accession: PP477311
- strain: LET-2088
field: authors
accession: PP477312
- strain: LET-2116
field: authors
accession: PP477313
- strain: LET-2093
field: authors
accession: PP477314
2 changes: 2 additions & 0 deletions ingest/rules/curate.smk
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ rule merge_metadata:
rule group_metadata:
input:
metadata="data/metadata_merged.tsv",
resolutions=config["grouping"]["resolutions"],
output:
metadata="results/metadata.tsv"
params:
Expand All @@ -161,6 +162,7 @@ rule group_metadata:
--metadata {input.metadata} \
--common-strain-fields {params.common_strain_fields} \
--segments {params.segments} \
--resolutions {input.resolutions} \
--output-metadata {output.metadata}
"""

Expand Down
78 changes: 68 additions & 10 deletions ingest/scripts/group_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

import argparse
import csv
import yaml
from collections import defaultdict
from sys import stderr
from typing import TypedDict,Any
from sys import stderr, exit
from typing import Any
from typing_extensions import TypedDict, NotRequired

def parse_args():
parser = argparse.ArgumentParser(description = __doc__)
Expand All @@ -29,6 +31,8 @@ def parse_args():
help="Segment names")
parser.add_argument('--metadata', metavar="TSV", required=True, type=str,
help="Input metadata file. ID column='accession'")
parser.add_argument('--resolutions', metavar='YAML', required=False,
help="Rules to resolve conflicts when grouping")
parser.add_argument('--output-metadata', metavar="TSV", required=True, type=str,
help="Input metadata file. ID column='strain'")
return parser.parse_args()
Expand All @@ -41,6 +45,20 @@ def group_by_strain(filename: str) -> dict[str,list]:
strains[row['strain']].append(row)
return strains


class Resolutions(TypedDict):
PICK_SEGMENT: NotRequired[list[dict]]
PICK_FIELD: NotRequired[list[dict]]

def parse_resolutions_yaml(fname:str) -> Resolutions:
with open(fname) as fh:
try:
resolutions = yaml.safe_load(fh)
except yaml.YAMLError as e:
print(e)
exit(2)
return resolutions

def log(msg)->None:
# Currently we just dump to STDERR, but this should be formalised
print(msg, file=stderr)
Expand All @@ -55,7 +73,36 @@ def get_segment(strain:str, row:dict[str,Any], segment_names: list[str])->str:
raise AssertionError(f"Accession '{accession}' (strain '{strain}') mapped to multiple segments: {', '.join(segments_present)}. Skipping this accession.")
return next(iter(segments_present))

def assign_segments(strain_name:str, rows:list, segment_names:list[str])->dict[str,dict]|None:

def resolve_segment(resolutions: Resolutions, strain:str, segment:str, accessions: list[str]) -> str|None:
rules = [el for el in resolutions.get('PICK_SEGMENT', []) if el.get('strain')==strain and el.get('segment')==segment]
if len(rules)==0:
return None
if len(rules)>1:
log(f"Malformed resolutions YAML - multiple PICK_SEGMENT blocks for strain={strain} segment={segment}")
exit(2)
rule = rules[0]
if rule['accession'] not in accessions:
log(f"ERROR! A PICK_SEGMENT resolution for strain {strain} for segment {segment} specified an accession which wasn't in the metadata.")
return None
return rule['accession']

def resolve_field(resolutions: Resolutions, strain:str, field:str, accessions: list[str]) -> str|None:
# TODO - this is essentially identical to `resolve_segment`. If they don't diverge, we should consolidate.
rules = [el for el in resolutions.get('PICK_FIELD', []) if el.get('strain')==strain and el.get('field')==field]
if len(rules)==0:
return None
if len(rules)>1:
log(f"Malformed resolutions YAML - multiple PICK_FIELD blocks for strain={strain} field={field}")
exit(2)
rule = rules[0]
if rule['accession'] not in accessions:
log(f"ERROR! A PICK_FIELD resolution for strain {strain} for field {field} specified an accession which wasn't in the metadata.")
return None
return rule['accession']


def assign_segments(strain_name:str, rows:list, segment_names:list[str], resolutions: Resolutions)->dict[str,dict]|None:
"""
Given rows (assigned to a strain) assign each to a segment. Error if
(1) The same row (sequence) is assigned to multiple sequences
Expand All @@ -71,13 +118,18 @@ def assign_segments(strain_name:str, rows:list, segment_names:list[str])->dict[s
log(e)
continue # ignore this row
rows_by_segment[segment_name].append(row)
# Drop any segments with more than one matching accession / sequence.
# Drop any segments with more than one matching accession / sequence unless there's a rule to resolve it
for segment_name, seg_rows in rows_by_segment.items():
accessions = [r['accession'] for r in seg_rows]
if len(seg_rows)==1:
segments[segment_name] = seg_rows[0]
elif len(seg_rows)>1:
if (accession := resolve_segment(resolutions, strain_name, segment_name, accessions)):
segments[segment_name] = next(iter([row for row in seg_rows if row['accession']==accession]))
log(f"Resolving '{strain_name}' to use accession {accession} for segment {segment_name} ")
continue
log(f"Strain '{strain_name}' had multiple accessions for segment {segment_name}. "
f"Accessions: {', '.join([r['accession'] for r in seg_rows])}. "
f"Accessions: {', '.join(accessions)}. "
"Skipping this segment.")
continue # ignore this segment (other segments for this strain may be OK)

Expand All @@ -95,7 +147,7 @@ class HeaderInfo(TypedDict):
segment_specific: list[dict[str,str]]


def pick_from_values(strain_name:str, field_name:str, rows:list, allow_empty=True)->str:
def pick_from_values(strain_name:str, field_name:str, rows:list, resolutions: Resolutions, allow_empty=True)->str:
values = set(row[field_name] for row in rows)
if allow_empty and "" in values and len(values)!=1:
values.remove("")
Expand All @@ -107,6 +159,11 @@ def pick_from_values(strain_name:str, field_name:str, rows:list, allow_empty=Tru
# continue, and use the error message printing below
pass

if (accession:=resolve_field(resolutions, strain_name, field_name, [row['accession'] for row in rows])):
value = next(iter([row[field_name] for row in rows if row['accession']==accession]))
log(f"Resolving '{strain_name}' to use {field_name}={value}")
return value

# want to print out helpful messages about disagreement, so order by most commonly observed
obs = defaultdict(list)
for row in rows:
Expand All @@ -126,8 +183,8 @@ def resolve_mismatch_dates(strain_name:str, values:set[str])->str:
raise ValueMatchingError()


def make_wide(strain: str, rows: list, segment_names: list[str], header_info:HeaderInfo) -> dict[str,str]|None:
segments = assign_segments(strain, rows, segment_names)
def make_wide(strain: str, rows: list, segment_names: list[str], header_info:HeaderInfo, resolutions: Resolutions) -> dict[str,str]|None:
segments = assign_segments(strain, rows, segment_names, resolutions)
if not segments:
return None

Expand All @@ -139,7 +196,7 @@ def make_wide(strain: str, rows: list, segment_names: list[str], header_info:Hea
observed_mismatches = False
for field_name in header_info['common']:
try:
metadata[field_name] = pick_from_values(strain, field_name, list(segments.values()))
metadata[field_name] = pick_from_values(strain, field_name, list(segments.values()), resolutions)
except ValueMatchingError as e:
log(e)
observed_mismatches = True
Expand Down Expand Up @@ -177,13 +234,14 @@ def header(segments:list[str], tsv_fields:list[str], common_strain_fields:list[s

if __name__=="__main__":
args = parse_args()
resolutions = parse_resolutions_yaml(args.resolutions) if args.resolutions else {}

strains = group_by_strain(args.metadata)

header_info = header(args.segments, list(next(iter(strains.values()))[0].keys()), args.common_strain_fields)

collapsed = [row
for row in [make_wide(strain, rows, args.segments, header_info) for strain,rows in strains.items()]
for row in [make_wide(strain, rows, args.segments, header_info, resolutions) for strain,rows in strains.items()]
if row is not None]
log("If any errors have been printed above, then those strains will have been dropped.")

Expand Down

0 comments on commit cbcc232

Please sign in to comment.