Merge branch 'master' into filter-ambiguous-min-max

victorlin committed Aug 19, 2021
2 parents 65fa83b + be57dbe commit 49fa23e
Showing 10 changed files with 387 additions and 61 deletions.
22 changes: 22 additions & 0 deletions CHANGES.md
@@ -3,6 +3,28 @@
## __NEXT__


+## 13.0.0 (17 August 2021)
+
+### Major Changes
+
+* filter: Skip metadata records with ambiguous month information in the `date` column when grouping by month, instead of randomly generating month values for those records. This change alters the behavior of the `filter` command for metadata with ambiguous month values. For these data, consider using `--group-by year` instead of `--group-by year month`. [#761][] (@huddlej)
+
+### Features
+
+* filter: When grouping by year or month, report the number of strains skipped due to ambiguous year and month both in the summary report at the end of filtering and in the `--output-log` contents. [#761][] (@huddlej)
+
+[#761]: https://github.com/nextstrain/augur/pull/761
+
+## 12.1.1 (13 August 2021)
+
+### Bug Fixes
+
+* filter: Fix parsing of missing data in metadata [#758][] (@huddlej)
+* filter: Fix probabilistic sampling with small values [#759][] (@huddlej)
+
+[#758]: https://github.com/nextstrain/augur/pull/758
+[#759]: https://github.com/nextstrain/augur/pull/759
+
## 12.1.0 (12 August 2021)

### Features
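To make the 13.0.0 behavior change concrete, here is a minimal sketch of the new month handling — not code from the diff; the strain names, dates, and printing are invented for illustration:

```python
import pandas as pd

metadata = pd.DataFrame(
    [
        {"strain": "strainA", "date": "2020"},        # year only: month is ambiguous
        {"strain": "strainB", "date": "2020-02-01"},  # complete date
    ]
).set_index("strain")

for strain, date in metadata["date"].items():
    parts = date.split("-")
    if len(parts) < 2:
        # 13.0.0 behavior: skip and report, instead of random.randint(1, 12)
        print(f"{strain}: skipped (ambiguous month)")
    else:
        print(f"{strain}: grouped under {(int(parts[0]), int(parts[1]))}")
```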
2 changes: 1 addition & 1 deletion augur/__version__.py
@@ -1,4 +1,4 @@
-__version__ = '12.1.0'
+__version__ = '13.0.0'


def is_augur_version_compatible(version):
188 changes: 143 additions & 45 deletions augur/filter.py
@@ -24,7 +24,6 @@
from .utils import is_vcf as filename_is_vcf, read_vcf, read_strains, get_numerical_dates, to_numeric_date_min, to_numeric_date_max, run_shell_command, shquote, is_date_ambiguous

comment_char = '#'
-MAX_NUMBER_OF_PROBABILISTIC_SAMPLING_ATTEMPTS = 10

SEQUENCE_ONLY_FILTERS = (
"min_length",
@@ -821,23 +820,30 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
    -------
    dict :
        A mapping of strain names to tuples corresponding to the values of the strain's group.
+    list :
+        A list of dictionaries with strains that were skipped from grouping and the reason why (see also: `apply_filters` output).

    >>> strains = ["strain1", "strain2"]
    >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020-01-01", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
    >>> group_by = ["region"]
-    >>> get_groups_for_subsampling(strains, metadata, group_by)
+    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
+    >>> group_by_strain
    {'strain1': ('Africa',), 'strain2': ('Europe',)}
+    >>> skipped_strains
+    []

    If we group by year or month, these groups are calculated from the date
    string.

    >>> group_by = ["year", "month"]
-    >>> get_groups_for_subsampling(strains, metadata, group_by)
+    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
+    >>> group_by_strain
    {'strain1': (2020, (2020, 1)), 'strain2': (2020, (2020, 2))}

    If we omit the grouping columns, the result will group by a dummy column.

-    >>> get_groups_for_subsampling(strains, metadata)
+    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
+    >>> group_by_strain
    {'strain1': ('_dummy',), 'strain2': ('_dummy',)}

    If we try to group by columns that don't exist, we get an error.
@@ -852,17 +858,41 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
    grouping to continue and print a warning message to stderr.

    >>> group_by = ["year", "month", "missing_column"]
-    >>> get_groups_for_subsampling(strains, metadata, group_by)
+    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
+    >>> group_by_strain
    {'strain1': (2020, (2020, 1), 'unknown'), 'strain2': (2020, (2020, 2), 'unknown')}

+    If we group by year and some records don't have year information in their
+    date fields, we should skip those records from the group output and track
+    which records were skipped for which reasons.
+
+    >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
+    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["year"])
+    >>> group_by_strain
+    {'strain2': (2020,)}
+    >>> skipped_strains
+    [{'strain': 'strain1', 'filter': 'skip_group_by_with_ambiguous_year', 'kwargs': ''}]
+
+    Similarly, if we group by month, we should skip records that don't have
+    month information in their date fields.
+
+    >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
+    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["month"])
+    >>> group_by_strain
+    {'strain2': ((2020, 2),)}
+    >>> skipped_strains
+    [{'strain': 'strain1', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}]
    """
    if group_by:
        groups = group_by
    else:
        groups = ("_dummy",)

    group_by_strain = {}
+    skipped_strains = []
    for strain in strains:
+        skip_strain = False
        group = []
        m = metadata.loc[strain].to_dict()
        # collect group specifiers
@@ -875,20 +905,33 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
                try:
                    year = int(m["date"].split('-')[0])
                except:
-                    print("WARNING: no valid year, skipping",strain, m["date"])
-                    continue
+                    skipped_strains.append({
+                        "strain": strain,
+                        "filter": "skip_group_by_with_ambiguous_year",
+                        "kwargs": "",
+                    })
+                    skip_strain = True
+                    break
                if c=='month':
                    try:
                        month = int(m["date"].split('-')[1])
                    except:
-                        month = random.randint(1,12)
+                        skipped_strains.append({
+                            "strain": strain,
+                            "filter": "skip_group_by_with_ambiguous_month",
+                            "kwargs": "",
+                        })
+                        skip_strain = True
+                        break

                    group.append((year, month))
                else:
                    group.append(year)
            else:
                group.append('unknown')

-        group_by_strain[strain] = tuple(group)
+        if not skip_strain:
+            group_by_strain[strain] = tuple(group)

# If we could not find any requested categories, we cannot complete subsampling.
distinct_groups = set(group_by_strain.values())
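The skip logic above hinges on how the `date` string splits on `-`: a non-numeric or empty year skips the record outright, and a missing month skips it when grouping by month. A standalone sketch of that rule with a hypothetical helper name (`year_month_or_none` is not part of augur):

```python
def year_month_or_none(date_string):
    """Hypothetical helper mirroring the skip rule above."""
    parts = date_string.split("-")
    try:
        year = int(parts[0])
    except ValueError:
        return None  # ambiguous year: the record is skipped entirely
    try:
        month = int(parts[1])
    except (ValueError, IndexError):
        return (year, None)  # ambiguous month: skipped when grouping by month
    return (year, month)

assert year_month_or_none("2020-02-01") == (2020, 2)
assert year_month_or_none("2020") == (2020, None)
assert year_month_or_none("") is None
```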
@@ -926,18 +969,13 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
file=sys.stderr,
)

-    return group_by_strain
+    return group_by_strain, skipped_strains


class PriorityQueue:
"""A priority queue implementation that automatically replaces lower priority
items in the heap with incoming higher priority items.
-    This implementation also allows the maximum size to be a fractional value
-    less than 1 in which case the heap size is sampled randomly from a Poisson
-    distribution with the given maximum size as the mean. This randomly sized
-    heap enables probabilistic subsampling.
Add a single record to a heap with a maximum of 2 records.
>>> queue = PriorityQueue(max_size=2)
@@ -971,28 +1009,12 @@ class PriorityQueue:
>>> list(queue.get_items())
[{'strain': 'strain4'}, {'strain': 'strain3'}]
-    Assign a fractional maximum size such that the corresponding queue limit is
-    sampled randomly from a Poisson distribution. For small values, we should
-    get a max size that is no more than 10 (this is an arbitrarily high number
-    above what we see for Poisson samples drawn with a mean of 0.1).
-
-    >>> queue = PriorityQueue(max_size=0.1)
-    >>> queue.max_size in set(range(10))
-    True
"""
    def __init__(self, max_size):
-        """Create a fixed size heap (priority queue) that allows the maximum size to be
-        calculated probabilistically from a Poisson process.
+        """Create a fixed size heap (priority queue)
        """
-        # Fractional heap sizes indicate probabilistic sampling.
-        if max_size < 1.0:
-            random_generator = np.random.default_rng()
-            self.max_size = random_generator.poisson(max_size)
-        else:
-            self.max_size = max_size
-
+        self.max_size = max_size
        self.heap = []
        self.counter = itertools.count()

@@ -1027,10 +1049,53 @@ def get_items(self):
yield item
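The queue's replacement behavior — unchanged by this commit — is the standard heapq push-pop pattern: once the heap is full, an incoming item displaces the lowest-priority entry. A self-contained sketch of that pattern, not the class itself:

```python
import heapq
import itertools

heap, counter, max_size = [], itertools.count(), 2
for priority, item in [(0.5, "strain1"), (0.6, "strain2"), (0.8, "strain3")]:
    entry = (priority, next(counter), item)
    if len(heap) < max_size:
        heapq.heappush(heap, entry)
    else:
        heapq.heappushpop(heap, entry)  # evicts the lowest-priority entry

print([item for _, _, item in sorted(heap, reverse=True)])
# ['strain3', 'strain2'] -- the two highest priorities survive
```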


-def priority_queue_factory(max_size):
-    """Return a callable for a priority queue with the given arguments.
-    """
-    return lambda: PriorityQueue(max_size=max_size)
+def create_queues_by_group(groups, max_size, max_attempts=100):
+    """Create a dictionary of priority queues per group for the given maximum size.
+
+    When the maximum size is fractional, probabilistically sample the maximum
+    size from a Poisson distribution. Make up to the given maximum number of
+    attempts to create queues for which the sum of their maximum sizes is
+    greater than zero.
+
+    Create queues for two groups with a fixed maximum size.
+
+    >>> groups = ("2015", "2016")
+    >>> queues = create_queues_by_group(groups, 2)
+    >>> sum(queue.max_size for queue in queues.values())
+    4
+
+    Create queues for two groups with a fractional maximum size. Their total max
+    size should still be an integer value greater than zero.
+
+    >>> queues = create_queues_by_group(groups, 0.1)
+    >>> int(sum(queue.max_size for queue in queues.values())) > 0
+    True
+    """
+    queues_by_group = {}
+    total_max_size = 0
+    attempts = 0
+
+    if max_size < 1.0:
+        random_generator = np.random.default_rng()
+
+    # For small fractional maximum sizes, it is possible to randomly select
+    # maximum queue sizes that all equal zero. When this happens, filtering
+    # fails unexpectedly. We make multiple attempts to create queues with
+    # maximum sizes greater than zero for at least one queue.
+    while total_max_size == 0 and attempts < max_attempts:
+        for group in groups:
+            if max_size < 1.0:
+                queue_max_size = random_generator.poisson(max_size)
+            else:
+                queue_max_size = max_size
+
+            queues_by_group[group] = PriorityQueue(queue_max_size)
+
+        total_max_size = sum(queue.max_size for queue in queues_by_group.values())
+        attempts += 1
+
+    return queues_by_group
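A quick numeric check of why the retry loop in `create_queues_by_group` matters: with a small fractional `max_size`, most Poisson draws are zero, yet the expected total across groups is still `len(groups) * max_size`. A sketch assuming only numpy; the seed is arbitrary:

```python
import numpy as np

rng = np.random.default_rng(seed=42)
n_groups, max_size = 1000, 0.1

sizes = rng.poisson(max_size, size=n_groups)
print(sizes.sum())          # ~100, i.e. n_groups * max_size on average
print((sizes == 0).mean())  # ~0.90: most queues are empty, so with few groups
                            # a single unlucky pass can total zero, which is
                            # what the retry loop guards against
```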


def register_arguments(parser):
@@ -1137,6 +1202,14 @@ def validate_arguments(args):
file=sys.stderr)
return False

+    # If the user requested grouping, confirm that other required inputs are provided, too.
+    if args.group_by and not any((args.sequences_per_group, args.subsample_max_sequences)):
+        print(
+            "ERROR: You must specify a number of sequences per group or maximum sequences to subsample.",
+            file=sys.stderr
+        )
+        return False

return True
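The new validation rejects grouping without a group size up front, instead of failing later during subsampling. A minimal illustration using `types.SimpleNamespace` as a stand-in for the parsed arguments (hypothetical values):

```python
import sys
from types import SimpleNamespace

args = SimpleNamespace(
    group_by=["region"],
    sequences_per_group=None,
    subsample_max_sequences=None,
)

if args.group_by and not any((args.sequences_per_group, args.subsample_max_sequences)):
    print(
        "ERROR: You must specify a number of sequences per group or maximum sequences to subsample.",
        file=sys.stderr,
    )
```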


@@ -1328,12 +1401,21 @@ def run(args):
# count the number of records per group. First, we need to get
# the groups for the given records.
try:
-group_by_strain = get_groups_for_subsampling(
+group_by_strain, skipped_strains = get_groups_for_subsampling(
    seq_keep,
    metadata,
    group_by,
)

+# Track strains skipped during grouping, so users know why those
+# strains were excluded from the analysis.
+for skipped_strain in skipped_strains:
+    filter_counts[(skipped_strain["filter"], skipped_strain["kwargs"])] += 1
+    valid_strains.remove(skipped_strain["strain"])
+
+    if args.output_log:
+        output_log_writer.writerow(skipped_strain)

if args.subsample_max_sequences and records_per_group is not None:
# Count the number of records per group. We will use this
# information to calculate the number of sequences per group
@@ -1344,11 +1426,17 @@
# Track the highest priority records, when we already
# know the number of sequences allowed per group.
if queues_by_group is None:
-    queues_by_group = defaultdict(priority_queue_factory(
-        max_size=sequences_per_group,
-    ))
+    queues_by_group = {}

for strain, group in group_by_strain.items():
+    # During this first pass, we do not know what all of the possible
+    # groups will be, so we need to build each group's queue as we first
+    # encounter the group.
+    if group not in queues_by_group:
+        queues_by_group[group] = PriorityQueue(
+            max_size=sequences_per_group,
+        )
+
    queues_by_group[group].add(
        metadata.loc[strain],
        priorities[strain],
@@ -1408,9 +1496,12 @@ def run(args):
sys.exit(1)

if queues_by_group is None:
-    queues_by_group = defaultdict(priority_queue_factory(
-        max_size=sequences_per_group,
-    ))
+    # We know all of the possible groups now from the first pass through
+    # the metadata, so we can create queues for all groups at once.
+    queues_by_group = create_queues_by_group(
+        records_per_group.keys(),
+        sequences_per_group,
+    )

# Make a second pass through the metadata, only considering records that
# have passed filters.
@@ -1425,7 +1516,7 @@
# during the first pass, but we want to minimize overall memory
# usage at the moment.
seq_keep = set(metadata.index.values) & valid_strains
-group_by_strain = get_groups_for_subsampling(
+group_by_strain, skipped_strains = get_groups_for_subsampling(
    seq_keep,
    metadata,
    group_by,
@@ -1439,6 +1530,7 @@

# If we have any records in queues, we have grouped results and need to
# stream the highest priority records to the requested outputs.
+num_excluded_subsamp = 0
if queues_by_group:
# Populate the set of strains to keep from the records in queues.
subsampled_strains = set()
@@ -1561,11 +1653,17 @@ def run(args):
"filter_by_date": "{count} of these were dropped because of their date (or lack of date)",
"filter_by_sequence_length": "{count} of these were dropped because they were shorter than minimum length of {min_length}bp",
"filter_by_non_nucleotide": "{count} of these were dropped because they had non-nucleotide characters",
"skip_group_by_with_ambiguous_year": "{count} were dropped during grouping due to ambiguous year information",
"skip_group_by_with_ambiguous_month": "{count} were dropped during grouping due to ambiguous month information",
"include": "{count} strains were added back because they were in {include_file}",
"include_by_include_where": "{count} sequences were added back because of '{include_where}'",
}
for (filter_name, filter_kwargs), count in filter_counts.items():
-    parameters = dict(json.loads(filter_kwargs))
+    if filter_kwargs:
+        parameters = dict(json.loads(filter_kwargs))
+    else:
+        parameters = {}

    parameters["count"] = count
    print("\t" + report_template_by_filter_name[filter_name].format(**parameters))

3 changes: 2 additions & 1 deletion augur/io.py
@@ -77,13 +77,14 @@ def read_metadata(metadata_file, id_columns=("strain", "name"), chunk_size=None)
...
(5, 14)
(5, 14)
-(1, 14)
+(2, 14)
"""
    kwargs = {
        "sep": None,
        "engine": "python",
        "skipinitialspace": True,
+        "na_filter": False,
    }

if chunk_size:
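The added `na_filter=False` keeps pandas from coercing empty fields to `NaN`, so records with missing values survive parsing as literal strings — consistent with the doctest's final chunk growing from `(1, 14)` to `(2, 14)`. A sketch of the difference on hypothetical CSV input:

```python
import io

import pandas as pd

csv = "strain,date\nstrainA,\nstrainB,2020-02-01\n"

default = pd.read_csv(io.StringIO(csv))
raw = pd.read_csv(io.StringIO(csv), na_filter=False)

print(default["date"].tolist())  # [nan, '2020-02-01'] -- empty field became NaN
print(raw["date"].tolist())      # ['', '2020-02-01'] -- kept as a literal string
```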