refactor filtering of user specified strains #178

Merged: 8 commits merged on Dec 14, 2023
108 changes: 11 additions & 97 deletions src/nplinker/loader.py
@@ -25,6 +25,7 @@
from nplinker.pairedomics.runbigscape import run_bigscape
from nplinker.pairedomics.strain_mappings_generator import podp_generate_strain_mappings
from nplinker.strain_collection import StrainCollection
from nplinker.strain_loader import load_user_strains
from nplinker.strains import Strain


@@ -205,10 +206,6 @@ def load(self):
# TODO add a config file option for this?
self._filter_only_common_strains()

# if the user specified a set of strains to be explicitly included, filter
# out everything except those strains
self._filter_user_strains()

# if we don't have at least *some* strains here it probably means missing mappings
# or a complete failure to parse things, so bail out
if len(self.strains) == 0:
@@ -382,13 +379,21 @@ def _load_mibig(self):
return True

def _load_strain_mappings(self):
# First load user's strain mappings
# 1. load strain mappings
sc = StrainCollection.read_json(self.strain_mappings_file)
for strain in sc:
self.strains.add(strain)
logger.info("Loaded {} non-MiBIG Strain objects".format(len(self.strains)))

# Then load MiBIG strain mappings
# 2. filter user specified strains (remove all that are not specified by the user).
# It's not allowed to specify an empty list of strains; validation of the file will fail otherwise.
if os.path.exists(self.include_strains_file):
logger.info(f"Loading user specified strains from file {self.include_strains_file}.")
user_strains = load_user_strains(self.include_strains_file)
logger.info(f"Loaded {len(user_strains)} user specified strains.")
self.strains.filter(user_strains)

# 3. load MiBIG strain mappings
if self._mibig_strain_bgc_mapping:
for k, v in self._mibig_strain_bgc_mapping.items():
strain = Strain(k)
@@ -573,28 +578,6 @@ def _load_optional(self):
self.description_text = open(self.description_file).read()
logger.debug("Parsed description text")

self.include_only_strains = set()
if os.path.exists(self.include_strains_file):
logger.debug("Loading include_strains from {}".format(self.include_strains_file))
strain_list = open(self.include_strains_file).readlines()
self.include_only_strains = StrainCollection()
for line_num, sid in enumerate(strain_list):
sid = sid.strip() # get rid of newline
try:
strain_ref_list = self.strains.lookup(sid)
except KeyError:
logger.warning(
'Line {} of {}: invalid/unknown strain ID "{}"'.format(
line_num + 1, self.include_strains_file, sid
)
)
continue
for strain in strain_ref_list:
self.include_only_strains.add(strain)
logger.debug(
"Found {} strain IDs in include_strains".format(len(self.include_only_strains))
)

def _filter_only_common_strains(self):
"""Filter strain population to only strains present in both genomic and molecular data."""
# TODO: Maybe there should be an option to specify which strains are used, both so we can
@@ -627,75 +610,6 @@
spec.strains.filter(common_strains)
logger.info("Strains filtered down to total of {}".format(len(self.strains)))

def _filter_user_strains(self):
"""If the user has supplied a list of strains to be explicitly included, go through the
existing sets of objects we have and remove any that only include other strains. This
involves an initial round of removing BGC and Spectrum objects, then a further round
of removing now-empty GCF and MolFam objects.
"""
if len(self.include_only_strains) == 0:
logger.info("No further strain filtering to apply")
return

logger.info(
"Found a list of {} strains to retain, filtering objects".format(
len(self.include_only_strains)
)
)

# filter the main list of strains
self.strains.filter(self.include_only_strains)

if len(self.strains) == 0:
logger.error("Strain list has been filtered down until it is empty! ")
logger.error(
"This probably indicates that you tried to specifically include a set of strains that had no overlap with the set common to metabolomics and genomics data (see the common_strains.csv in the dataset folder for a list of these"
)
raise Exception("No strains left after filtering, cannot continue!")

# get the list of BGCs which have a strain found in the set we were given
bgcs_to_retain = {bgc for bgc in self.bgcs if bgc.strain in self.include_only_strains}
# get the list of spectra which have at least one strain in the set
spectra_to_retain = {
spec
for spec in self.spectra
for sstrain in spec.strains
if sstrain in self.include_only_strains
}

logger.info(
"Current / filtered BGC counts: {} / {}".format(len(self.bgcs), len(bgcs_to_retain))
)
logger.info(
"Current / filtered spectra counts: {} / {}".format(
len(self.spectra), len(spectra_to_retain)
)
)

self.bgcs = list(bgcs_to_retain)

self.spectra = list(spectra_to_retain)
# also need to filter the set of strains attached to each spectrum
for i, spec in enumerate(self.spectra):
spec.strains.filter(self.include_only_strains)
spec.id = i

# now filter GCFs and MolFams based on the filtered BGCs and Spectra
gcfs = {parent for bgc in self.bgcs for parent in bgc.parents}
logger.info("Current / filtered GCF counts: {} / {}".format(len(self.gcfs), len(gcfs)))
self.gcfs = list(gcfs)
# filter each GCF's strain list
for gcf in self.gcfs:
gcf.strains.filter(self.include_only_strains)

molfams = {spec.family for spec in self.spectra}
logger.info(
"Current / filtered MolFam counts: {} / {}".format(len(self.molfams), len(molfams))
)
self.molfams = list(molfams)
for i, molfam in enumerate(self.molfams):
molfam.id = i


def find_via_glob(path, file_type, optional=False):
try:
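Taken together, the loader now handles user-specified strains inside _load_strain_mappings instead of the removed _filter_user_strains pass. Below is a minimal standalone sketch of that flow; the file names are hypothetical stand-ins for the paths the real DatasetLoader derives from its configuration.

import os

from nplinker.strain_collection import StrainCollection
from nplinker.strain_loader import load_user_strains

# hypothetical paths; the loader derives these from the dataset layout
strain_mappings_file = "strain_mappings.json"
include_strains_file = "include_strains.json"

# 1. load strain mappings into a collection
strains = StrainCollection.read_json(strain_mappings_file)

# 2. if the user supplied a strains file, keep only those strains.
# The schema rejects an empty "strain_ids" list, so an empty selection
# fails validation instead of silently emptying the collection.
if os.path.exists(include_strains_file):
    user_strains = load_user_strains(include_strains_file)
    strains.filter(user_strains)

print(f"{len(strains)} strains left after optional user filtering")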
4 changes: 4 additions & 0 deletions src/nplinker/schemas/__init__.py
@@ -12,6 +12,7 @@
"GENOME_BGC_MAPPINGS_SCHEMA",
"STRAIN_MAPPINGS_SCHEMA",
"PODP_ADAPTED_SCHEMA",
"USER_STRAINS_SCHEMA",
"validate_podp_json",
]

@@ -24,3 +25,6 @@

with open(SCHEMA_DIR / "strain_mappings_schema.json", "r") as f:
STRAIN_MAPPINGS_SCHEMA = json.load(f)

with open(SCHEMA_DIR / "user_strains.json", "r") as f:
USER_STRAINS_SCHEMA = json.load(f)
30 changes: 30 additions & 0 deletions src/nplinker/schemas/user_strains.json
@@ -0,0 +1,30 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "https://raw.githubusercontent.com/NPLinker/nplinker/main/src/nplinker/schemas/user_strains.json",
"title": "User specificed strains",
"description": "A list of strain IDs specified by user",
"type": "object",
"required": [
"strain_ids"
],
"properties": {
"strain_ids": {
"type": "array",
"title": "Strain IDs",
"description": "A list of strain IDs specificed by user. The strain IDs must be the same as the ones in the strain mappings file.",
"items": {
"type": "string",
"minLength": 1
},
"minItems": 1,
"uniqueItems": true
},
"version": {
"type": "string",
"enum": [
"1.0"
]
}
},
"additionalProperties": false
}
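An illustrative sketch of what this schema accepts and rejects, mirroring the new tests further down; the payloads here are made-up examples.

from jsonschema import validate
from jsonschema.exceptions import ValidationError

from nplinker.schemas import USER_STRAINS_SCHEMA

valid = {"strain_ids": ["strain1", "strain2"], "version": "1.0"}
validate(instance=valid, schema=USER_STRAINS_SCHEMA)  # passes, no exception

invalid = {"strain_ids": []}  # violates "minItems": 1
try:
    validate(instance=invalid, schema=USER_STRAINS_SCHEMA)
except ValidationError as err:
    print(err.message)  # "[] is too short"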
6 changes: 5 additions & 1 deletion src/nplinker/strain_collection.py
@@ -110,7 +110,11 @@ def remove(self, strain: Strain):
raise ValueError(f"Strain {strain} not found in strain collection.")

def filter(self, strain_set: set[Strain]):
"""Remove all strains that are not in strain_set from the strain collection."""
"""Remove all strains that are not in strain_set from the strain collection.

Args:
strain_set(set[Strain]): Set of strains to keep.
"""
# note that we need to copy the list of strains, as we are modifying it
for strain in self._strains.copy():
if strain not in strain_set:
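A small usage sketch of the clarified filter contract; the strain IDs are invented, and it assumes Strain equality is based on the strain ID, which is what the new loading code relies on when filtering by user-specified strains.

from nplinker.strain_collection import StrainCollection
from nplinker.strains import Strain

collection = StrainCollection()
for strain_id in ("strain1", "strain2", "strain3"):
    collection.add(Strain(strain_id))

# keep only the strains present in the given set; everything else is removed
collection.filter({Strain("strain1"), Strain("strain3")})

print(len(collection))  # 2, assuming Strain hashing/equality uses the strain ID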
35 changes: 35 additions & 0 deletions src/nplinker/strain_loader.py
@@ -0,0 +1,35 @@
import json
from os import PathLike
from jsonschema import validate
from nplinker.logconfig import LogConfig
from nplinker.schemas import USER_STRAINS_SCHEMA
from .strains import Strain


logger = LogConfig.getLogger(__name__)


def load_user_strains(json_file: str | PathLike) -> set[Strain]:
"""Load user specified strains from a JSON file.

The JSON file must follow the schema defined in "nplinker/schemas/user_strains.json".
An example content of the JSON file:
{"strain_ids": ["strain1", "strain2"]}

Args:
json_file(str | PathLike): Path to the JSON file containing user specified strains.

Returns:
set[Strain]: A set of user specified strains.
"""
with open(json_file, "r") as f:
json_data = json.load(f)

# validate json data
validate(instance=json_data, schema=USER_STRAINS_SCHEMA)

strains = set()
for strain_id in json_data["strain_ids"]:
strains.add(Strain(strain_id))

return strains
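A short end-to-end sketch of calling the new helper; the temporary file and its contents are invented for illustration.

import json
import os
import tempfile

from nplinker.strain_loader import load_user_strains

# write an example user strains file that follows the schema shown above
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"strain_ids": ["strain1", "strain2"]}, f)
    json_path = f.name

strains = load_user_strains(json_path)  # returns a set of Strain objects
print(len(strains))  # 2

os.remove(json_path)  # clean up the temporary file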
49 changes: 49 additions & 0 deletions tests/schemas/test_user_strains_schema.py
@@ -0,0 +1,49 @@
import pytest
from jsonschema import validate
from jsonschema.exceptions import ValidationError
from nplinker.schemas import USER_STRAINS_SCHEMA


# Test schema against invalid data
data_no_strain_ids = {"version": "1.0"}
data_empty_strain_ids = {"strain_ids": [], "version": "1.0"}
data_invalid_strain_ids = {"strain_ids": [1, 2, 3], "version": "1.0"}
data_empty_version = {"strain_ids": ["strain1", "strain2"], "version": ""}
data_invalid_version = {"strain_ids": ["strain1", "strain2"], "version": "1.0.0"}


@pytest.mark.parametrize(
"data, expected",
[
[data_no_strain_ids, "'strain_ids' is a required property"],
[data_empty_strain_ids, "[] is too short"],
[data_invalid_strain_ids, "1 is not of type 'string'"],
[data_empty_version, "'' is not one of ['1.0']"],
[data_invalid_version, "'1.0.0' is not one of ['1.0']"],
],
)
def test_invalid_data(data, expected):
"""Test user strains schema against invalid data."""
with pytest.raises(ValidationError) as e:
validate(data, USER_STRAINS_SCHEMA)
assert e.value.message == expected


# Test schema against valid data
data = {"strain_ids": ["strain1", "strain2"], "version": "1.0"}
data_no_version = {"strain_ids": ["strain1", "strain2"]}


@pytest.mark.parametrize(
"data",
[
data,
data_no_version,
],
)
def test_valid_data(data):
"""Test user strains schema against valid data."""
try:
validate(data, USER_STRAINS_SCHEMA)
except ValidationError:
pytest.fail("Unexpected ValidationError")
106 changes: 0 additions & 106 deletions tests/test_loader.py

This file was deleted.
