diff --git a/.devcontainer.json b/.devcontainer.json index d490e93c..7ccbd538 100644 --- a/.devcontainer.json +++ b/.devcontainer.json @@ -15,20 +15,31 @@ // "forwardPorts": [], // Use 'postCreateCommand' to run commands after the container is created. - "postCreateCommand": "mamba env create -n haptools -f dev-env.yml && conda run -n haptools poetry install", + "postCreateCommand": "mamba env create -n haptools -f dev-env.yml && conda run -n haptools poetry config virtualenvs.in-project true && conda run -n haptools poetry install", // Configure tool-specific properties. "customizations": { "vscode": { - "extensions": ["ms-python.python"], + "extensions": [ + "ms-python.python", + "ms-python.black-formatter" + ], "settings": { + "python.analysis.typeCheckingMode": "off", // TODO: set to "strict" "python.condaPath": "/opt/conda/condabin/conda", - "python.defaultInterpreterPath": "/opt/conda/envs/haptools/bin/python", "python.terminal.activateEnvironment": true, "python.terminal.activateEnvInCurrentTerminal": true, - "python.venvFolders": ["/home/vscode/.cache/pypoetry/virtualenvs"], + "python.venvPath": "/workspaces/haptools/.venv", + "python.defaultInterpreterPath": "/workspaces/haptools/.venv/bin/python", + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.pytestEnabled": true, + "python.testing.unittestEnabled": false, "terminal.integrated.environmentChangesRelaunch": true, - "terminal.integrated.hideOnStartup": "always" + "editor.defaultFormatter": "ms-python.black-formatter", + "terminal.integrated.hideOnStartup": "always", + "files.eol": "\n" } } } diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 395f6b35..20cabc38 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -39,6 +39,7 @@ jobs: auto-activate-base: false miniforge-version: latest use-mamba: true + conda-remove-defaults: "true" - name: Get Date id: get-date diff --git a/haptools/clump.py b/haptools/clump.py index 2400b56c..da1a42ed 100644 --- a/haptools/clump.py +++ b/haptools/clump.py @@ -6,7 +6,7 @@ import numpy as np -from .data import Genotypes, GenotypesVCF, GenotypesTR, GenotypesPLINKTR +from .data import Genotypes, GenotypesVCF, GenotypesPLINK, GenotypesTR, GenotypesPLINKTR class Variant: diff --git a/haptools/data/breakpoints.py b/haptools/data/breakpoints.py index 6610ca46..e95aaae5 100644 --- a/haptools/data/breakpoints.py +++ b/haptools/data/breakpoints.py @@ -23,7 +23,7 @@ # This tuple lists the haplotype blocks in a sample, one set for each chromosome # Let's define a type alias, "SampleBlocks", for future use... SampleBlocks = NewType( - "SampleBlocks", "list[npt.NDArray[HapBlock], npt.NDArray[HapBlock]]]" + "SampleBlocks", "list[npt.NDArray[HapBlock], npt.NDArray[HapBlock]]]" # type: ignore ) diff --git a/haptools/data/genotypes.py b/haptools/data/genotypes.py index 953d8210..4d812d10 100644 --- a/haptools/data/genotypes.py +++ b/haptools/data/genotypes.py @@ -14,7 +14,7 @@ from cyvcf2 import VCF, Variant try: - import trtools.utils.tr_harmonizer as trh + import trtools.utils.tr_harmonizer as trh # type: ignore except ModuleNotFoundError: from . import tr_harmonizer as trh @@ -219,7 +219,7 @@ def _variant_arr(self, record: Variant): Parameters ---------- record: Variant - A cyvcf2.Variant object from which to fetch metadata + A Variant object from which to fetch metadata Returns ------- @@ -231,20 +231,20 @@ def _variant_arr(self, record: Variant): dtype=self.variants.dtype, ) - def _vcf_iter(self, vcf: cyvcf2.VCF, region: str): + def _vcf_iter(self, vcf: VCF, region: str): """ Yield all variants within a region in the VCF file. Parameters ---------- vcf: VCF - The cyvcf2.VCF object from which to fetch variant records + The VCF object from which to fetch variant records region : str, optional See documentation for :py:meth:`~.Genotypes.read` Returns ------- - vcffile : cyvcf2.VCF + vcffile : VCF Iterable cyvcf2 instance. """ return vcf(region) @@ -255,8 +255,8 @@ def _return_data(self, variant: Variant): Parameters ---------- - variant: cyvcf2.Variant - A cyvcf2.Variant object from which to fetch genotypes + variant: Variant + A Variant object from which to fetch genotypes Returns ------- @@ -274,7 +274,7 @@ def _iterate(self, vcf: VCF, region: str = None, variants: set[str] = None): Parameters ---------- vcf: VCF - The cyvcf2.VCF object from which to fetch variant records + The VCF object from which to fetch variant records region : str, optional See documentation for :py:meth:`~.Genotypes.read` variants : set[str], optional @@ -805,7 +805,13 @@ def write(self): record.samples[sample].phased = self.data[samp_idx, var_idx, 2] # write the record to a file vcf.write(record) - vcf.close() + try: + vcf.close() + except OSError as e: + if e.errno == 9 and len(self.variants) == 0: + self.log.warning(f"No variants in {self.fname}.") + else: + raise e class TRRecordHarmonizerRegion(trh.TRRecordHarmonizer): @@ -909,14 +915,14 @@ def load( genotypes.check_phase() return genotypes - def _vcf_iter(self, vcf: cyvcf2.VCF, region: str = None): + def _vcf_iter(self, vcf: VCF, region: str = None): """ Collect GTs (trh.TRRecord objects) to iterate over Parameters ---------- vcf: VCF - The cyvcf2.VCF object from which to fetch variant records + The VCF object from which to fetch variant records region : str, optional See documentation for :py:meth:`~.Genotypes.read` @@ -1066,7 +1072,7 @@ def read_samples(self, samples: set[str] = None): self.samples = { ct: samp[col_idx] for ct, samp in enumerate(psamples) - if (samples is None) or (samp[col_idx] in samples) + if len(samp) and ((samples is None) or (samp[col_idx] in samples)) } indices = np.array(list(self.samples.keys()), dtype=np.uint32) self.samples = tuple(self.samples.values()) @@ -1289,7 +1295,18 @@ def read( super(Genotypes, self).read() sample_idxs = self.read_samples(samples) - pv = pgenlib.PvarReader(bytes(str(self.fname.with_suffix(".pvar")), "utf8")) + pvar_fname = bytes(str(self.fname.with_suffix(".pvar")), "utf8") + try: + pv = pgenlib.PvarReader(pvar_fname) + except RuntimeError as e: + if e.args[0].decode("utf8").startswith("No variants in"): + self.log.warning(f"No variants in {pvar_fname}.") + self.data = np.empty( + (len(sample_idxs), 0, (2 + (not self._prephased))), dtype=np.uint8 + ) + return + else: + raise e with pgenlib.PgenReader( bytes(str(self.fname), "utf8"), sample_subset=sample_idxs, pvar=pv @@ -1544,12 +1561,23 @@ def write(self): chunks = len(self.variants) # write the pgen file - pv = pgenlib.PvarReader(bytes(str(self.fname.with_suffix(".pvar")), "utf8")) + try: + max_allele_ct = pgenlib.PvarReader( + bytes(str(self.fname.with_suffix(".pvar")), "utf8") + ).get_max_allele_ct() + except RuntimeError as e: + if len(self.variants) == 0: + # write an empty pgen file + with open(self.fname, "wb"): + pass + return + else: + raise e with pgenlib.PgenWriter( filename=bytes(str(self.fname), "utf8"), sample_ct=len(self.samples), variant_ct=len(self.variants), - allele_ct_limit=pv.get_max_allele_ct(), + allele_ct_limit=max_allele_ct, nonref_flags=False, hardcall_phase_present=True, ) as pgen: diff --git a/haptools/data/haplotypes.py b/haptools/data/haplotypes.py index f70e4709..5923e828 100644 --- a/haptools/data/haplotypes.py +++ b/haptools/data/haplotypes.py @@ -461,7 +461,7 @@ def extras_order(cls) -> tuple[str]: """ return tuple(extra.name for extra in cls._extras) - def transform(self, genotypes: GenotypesVCF) -> npt.NDArray[bool]: + def transform(self, genotypes: GenotypesVCF) -> npt.NDArray: """ Transform a genotypes matrix via the current haplotype @@ -478,9 +478,9 @@ def transform(self, genotypes: GenotypesVCF) -> npt.NDArray[bool]: Returns ------- - npt.NDArray[bool] - A 2D matrix of shape (num_samples, 2) where each entry in the matrix - denotes the presence of the haplotype in one chromosome of a sample + npt.NDArray + A 2D matrix of shape (num_samples, 2) where each entry in the matrix is a + bool denoting the presence of the haplotype in one chromosome of a sample """ var_IDs = self.varIDs # ensure the variants in the Genotypes object are ordered according to var_IDs @@ -1198,12 +1198,16 @@ def __iter__( indexed = True try: haps_file = TabixFile(str(self.fname)) + if region is not None: + haps_file.fetch(region=region, multiple_iterators=True) except OSError: indexed = False - # if the user requested a specific region or subset of haplotypes and the file + except ValueError: + indexed = False + # If the user requested a specific region or subset of haplotypes and the file # is indexed, then we should handle it using tabix # else, we use a regular text opener - b/c there's no benefit to using tabix - if region or (haplotypes and indexed): + if (region or haplotypes) and indexed: haps_file = TabixFile(str(self.fname)) metas, extras = self.check_header(list(haps_file.header)) types = self._get_field_types(extras, metas.get("order")) @@ -1232,8 +1236,8 @@ def __iter__( ) haps_file.close() else: - # the file is not indexed, so we can't assume it's sorted, either - # use hook_compressed to automatically handle gz files + # The file is not indexed, so we can't assume it's sorted, either + # Use hook_compressed to automatically handle gz files with self.hook_compressed(self.fname, mode="r") as haps: self.log.info("Not taking advantage of indexing.") header_lines = [] diff --git a/haptools/data/tr_harmonizer.py b/haptools/data/tr_harmonizer.py index 9813d901..84015d9f 100644 --- a/haptools/data/tr_harmonizer.py +++ b/haptools/data/tr_harmonizer.py @@ -3,6 +3,7 @@ Handles VCFs generated by various TR genotyping tools """ import re +import math import enum import warnings from typing import ( diff --git a/haptools/transform.py b/haptools/transform.py index 4e406a39..f43ccb84 100644 --- a/haptools/transform.py +++ b/haptools/transform.py @@ -593,6 +593,8 @@ def transform_haps( "that the IDs in your .hap file correspond with those you provided. " f"Here are the first few missing haplotypes: {diff[:first_few]}" ) + if len(hp.data) == 0: + raise ValueError("Didn't load any haplotypes from the .hap file") log.info("Extracting variants from haplotypes") variants = {vr.id for id in hp.type_ids["H"] for vr in hp.data[id].variants} diff --git a/noxfile.py b/noxfile.py index eec3d2d4..14fbda26 100644 --- a/noxfile.py +++ b/noxfile.py @@ -3,9 +3,9 @@ import shutil from pathlib import Path -import nox -from nox_poetry import Session -from nox_poetry import session +import nox # type: ignore +from nox_poetry import Session # type: ignore +from nox_poetry import session # type: ignore package = "haptools" diff --git a/tests/bench_transform.py b/tests/bench_transform.py index 3f3f3d3f..6e5bce08 100755 --- a/tests/bench_transform.py +++ b/tests/bench_transform.py @@ -7,6 +7,7 @@ from datetime import datetime import click +import matplotlib import numpy as np import matplotlib.pyplot as plt diff --git a/tests/test_data.py b/tests/test_data.py index 2e0057c2..36377ce3 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -569,6 +569,20 @@ def test_write_genotypes(self): fname.with_suffix(".pvar").unlink() fname.unlink() + def test_write_genotypes_empty(self): + fname = DATADIR / "test_write.pgen" + gts = GenotypesPLINK(fname=fname) + gts.data = np.empty((0, 0, 0), dtype=np.uint8) + gts.samples = () + gts.variants = np.empty(0, dtype=gts.variants.dtype) + gts.write() + gts.read() + + # clean up afterwards: delete the files we created + fname.with_suffix(".psam").unlink() + fname.with_suffix(".pvar").unlink() + fname.unlink() + def test_write_genotypes_prephased(self): gts = self._get_fake_genotypes_plink() @@ -1851,6 +1865,16 @@ def test_write_ref_alt(self, multiallelic=False): fname.unlink() + def test_write_empty(self): + fname = Path("test.vcf") + gts = GenotypesVCF(fname=fname) + gts.samples = () + gts.variants = np.array([], dtype=gts.variants.dtype) + gts.data = np.empty((0, 0, 0), dtype=np.uint8) + gts.write() + gts.read() + fname.unlink() + def test_write_multiallelic(self): self.test_write_ref_alt(multiallelic=True) diff --git a/tests/test_transform.py b/tests/test_transform.py index 10f1c8f4..c094e720 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -371,3 +371,42 @@ def test_ancestry_from_bp(capfd): captured = capfd.readouterr() assert captured.out == ancestry_results assert result.exit_code == 0 + + +def test_transform_empty_hap(capfd): + gt_file = DATADIR / "simple.vcf.gz" + hp_file = Path("empty.hap") + hp_file_gz = Path("empty.hap.gz") + hp_file_idx = Path("empty.hap.gz.tbi") + + # create an empty .hap file + with open(hp_file, "w") as f: + f.write("") + + # can we run transform with the empty hap file? + cmd = f"transform --region 1:10116-10122 {gt_file} {hp_file}" + runner = CliRunner() + result = runner.invoke(main, cmd.split(" ")) + captured = capfd.readouterr() + assert all(line for line in captured.out.split("\n") if line.startswith("#")) + assert result.exit_code != 0 + + # now, index the empty hap file and try again + cmd = f"index {hp_file}" + runner = CliRunner() + result = runner.invoke(main, cmd.split(" "), catch_exceptions=False) + captured = capfd.readouterr() + assert result.exit_code == 0 + assert hp_file_gz.exists() + assert hp_file_idx.exists() + + # what about now? does it still fail? + cmd = f"transform --region 1:10116-10122 {gt_file} {hp_file_gz}" + runner = CliRunner() + result = runner.invoke(main, cmd.split(" ")) + captured = capfd.readouterr() + assert result.exit_code != 0 + + hp_file.unlink() + hp_file_gz.unlink() + hp_file_idx.unlink()