fix: 1) improve error in transform when providing an empty hap file and a --region and 2) allow for calling write() on Genotypes objects without variants #264

Merged: 9 commits on Dec 11, 2024
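For context (not part of the diff): a minimal sketch of fix #1, assuming the CLI behavior exercised by the new test added in tests/test_transform.py below. With this change, running transform on an empty .hap file together with a --region exits with a clear error instead of silently writing empty output. The import path of the click entry point and the data file locations are assumptions based on the test code.

```python
# Sketch only; the `main` import path and file locations are assumptions drawn
# from the test code further down in this PR.
from click.testing import CliRunner

from haptools.__main__ import main  # assumed CLI entry point

runner = CliRunner()
result = runner.invoke(
    main,
    "transform --region 1:10116-10122 tests/data/simple.vcf.gz empty.hap".split(" "),
)
# With this PR, the command now fails with the new error:
# "Didn't load any haplotypes from the .hap file"
assert result.exit_code != 0
```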
21 changes: 16 additions & 5 deletions .devcontainer.json
@@ -15,20 +15,31 @@
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": "mamba env create -n haptools -f dev-env.yml && conda run -n haptools poetry install",
"postCreateCommand": "mamba env create -n haptools -f dev-env.yml && conda run -n haptools poetry config virtualenvs.in-project true && conda run -n haptools poetry install",

// Configure tool-specific properties.
"customizations": {
"vscode": {
"extensions": ["ms-python.python"],
"extensions": [
"ms-python.python",
"ms-python.black-formatter"
],
"settings": {
"python.analysis.typeCheckingMode": "off", // TODO: set to "strict"
"python.condaPath": "/opt/conda/condabin/conda",
"python.defaultInterpreterPath": "/opt/conda/envs/haptools/bin/python",
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true,
"python.venvFolders": ["/home/vscode/.cache/pypoetry/virtualenvs"],
"python.venvPath": "/workspaces/haptools/.venv",
"python.defaultInterpreterPath": "/workspaces/haptools/.venv/bin/python",
"python.testing.pytestArgs": [
"tests"
],
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"terminal.integrated.environmentChangesRelaunch": true,
"terminal.integrated.hideOnStartup": "always"
"editor.defaultFormatter": "ms-python.black-formatter",
"terminal.integrated.hideOnStartup": "always",
"files.eol": "\n"
}
}
}
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -39,6 +39,7 @@ jobs:
auto-activate-base: false
miniforge-version: latest
use-mamba: true
conda-remove-defaults: "true"

- name: Get Date
id: get-date
2 changes: 1 addition & 1 deletion haptools/clump.py
@@ -6,7 +6,7 @@

import numpy as np

from .data import Genotypes, GenotypesVCF, GenotypesTR, GenotypesPLINKTR
from .data import Genotypes, GenotypesVCF, GenotypesPLINK, GenotypesTR, GenotypesPLINKTR


class Variant:
2 changes: 1 addition & 1 deletion haptools/data/breakpoints.py
@@ -23,7 +23,7 @@
# This tuple lists the haplotype blocks in a sample, one set for each chromosome
# Let's define a type alias, "SampleBlocks", for future use...
SampleBlocks = NewType(
"SampleBlocks", "list[npt.NDArray[HapBlock], npt.NDArray[HapBlock]]]"
"SampleBlocks", "list[npt.NDArray[HapBlock], npt.NDArray[HapBlock]]]" # type: ignore
)


58 changes: 43 additions & 15 deletions haptools/data/genotypes.py
@@ -14,7 +14,7 @@
from cyvcf2 import VCF, Variant

try:
import trtools.utils.tr_harmonizer as trh
import trtools.utils.tr_harmonizer as trh # type: ignore
except ModuleNotFoundError:
from . import tr_harmonizer as trh

@@ -219,7 +219,7 @@ def _variant_arr(self, record: Variant):
Parameters
----------
record: Variant
A cyvcf2.Variant object from which to fetch metadata
A Variant object from which to fetch metadata

Returns
-------
@@ -231,20 +231,20 @@ def _variant_arr(self, record: Variant):
dtype=self.variants.dtype,
)

def _vcf_iter(self, vcf: cyvcf2.VCF, region: str):
def _vcf_iter(self, vcf: VCF, region: str):
"""
Yield all variants within a region in the VCF file.

Parameters
----------
vcf: VCF
The cyvcf2.VCF object from which to fetch variant records
The VCF object from which to fetch variant records
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`

Returns
-------
vcffile : cyvcf2.VCF
vcffile : VCF
Iterable cyvcf2 instance.
"""
return vcf(region)
@@ -255,8 +255,8 @@ def _return_data(self, variant: Variant):

Parameters
----------
variant: cyvcf2.Variant
A cyvcf2.Variant object from which to fetch genotypes
variant: Variant
A Variant object from which to fetch genotypes

Returns
-------
@@ -274,7 +274,7 @@ def _iterate(self, vcf: VCF, region: str = None, variants: set[str] = None):
Parameters
----------
vcf: VCF
The cyvcf2.VCF object from which to fetch variant records
The VCF object from which to fetch variant records
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
@@ -805,7 +805,13 @@ def write(self):
record.samples[sample].phased = self.data[samp_idx, var_idx, 2]
# write the record to a file
vcf.write(record)
vcf.close()
try:
vcf.close()
except OSError as e:
if e.errno == 9 and len(self.variants) == 0:
self.log.warning(f"No variants in {self.fname}.")
else:
raise e


class TRRecordHarmonizerRegion(trh.TRRecordHarmonizer):
@@ -909,14 +915,14 @@ def load(
genotypes.check_phase()
return genotypes

def _vcf_iter(self, vcf: cyvcf2.VCF, region: str = None):
def _vcf_iter(self, vcf: VCF, region: str = None):
"""
Collect GTs (trh.TRRecord objects) to iterate over

Parameters
----------
vcf: VCF
The cyvcf2.VCF object from which to fetch variant records
The VCF object from which to fetch variant records
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`

@@ -1066,7 +1072,7 @@ def read_samples(self, samples: set[str] = None):
self.samples = {
ct: samp[col_idx]
for ct, samp in enumerate(psamples)
if (samples is None) or (samp[col_idx] in samples)
if len(samp) and ((samples is None) or (samp[col_idx] in samples))
}
indices = np.array(list(self.samples.keys()), dtype=np.uint32)
self.samples = tuple(self.samples.values())
@@ -1289,7 +1295,18 @@ def read(
super(Genotypes, self).read()

sample_idxs = self.read_samples(samples)
pv = pgenlib.PvarReader(bytes(str(self.fname.with_suffix(".pvar")), "utf8"))
pvar_fname = bytes(str(self.fname.with_suffix(".pvar")), "utf8")
try:
pv = pgenlib.PvarReader(pvar_fname)
except RuntimeError as e:
if e.args[0].decode("utf8").startswith("No variants in"):
self.log.warning(f"No variants in {pvar_fname}.")
self.data = np.empty(
(len(sample_idxs), 0, (2 + (not self._prephased))), dtype=np.uint8
)
return
else:
raise e

with pgenlib.PgenReader(
bytes(str(self.fname), "utf8"), sample_subset=sample_idxs, pvar=pv
@@ -1544,12 +1561,23 @@ def write(self):
chunks = len(self.variants)

# write the pgen file
pv = pgenlib.PvarReader(bytes(str(self.fname.with_suffix(".pvar")), "utf8"))
try:
max_allele_ct = pgenlib.PvarReader(
bytes(str(self.fname.with_suffix(".pvar")), "utf8")
).get_max_allele_ct()
except RuntimeError as e:
if len(self.variants) == 0:
# write an empty pgen file
with open(self.fname, "wb"):
pass
return
else:
raise e
with pgenlib.PgenWriter(
filename=bytes(str(self.fname), "utf8"),
sample_ct=len(self.samples),
variant_ct=len(self.variants),
allele_ct_limit=pv.get_max_allele_ct(),
allele_ct_limit=max_allele_ct,
nonref_flags=False,
hardcall_phase_present=True,
) as pgen:
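Not part of the diff: a minimal sketch of fix #2, mirroring the new test_write_empty test in tests/test_data.py below. A GenotypesVCF object holding zero variants can now be written; closing the VCF writer no longer raises, and a warning about the empty file is logged instead. The GenotypesPLINK path likewise writes an empty .pgen file. The output file name here is arbitrary.

```python
# Sketch only (mirrors test_write_empty further down in this PR).
from pathlib import Path

import numpy as np

from haptools.data import GenotypesVCF

gts = GenotypesVCF(fname=Path("empty.vcf"))  # arbitrary output path
gts.samples = ()
gts.variants = np.array([], dtype=gts.variants.dtype)
gts.data = np.empty((0, 0, 0), dtype=np.uint8)
gts.write()  # previously raised OSError (errno 9) on close; now logs a warning
gts.read()   # reading the empty file back also succeeds
```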
20 changes: 12 additions & 8 deletions haptools/data/haplotypes.py
@@ -461,7 +461,7 @@ def extras_order(cls) -> tuple[str]:
"""
return tuple(extra.name for extra in cls._extras)

def transform(self, genotypes: GenotypesVCF) -> npt.NDArray[bool]:
def transform(self, genotypes: GenotypesVCF) -> npt.NDArray:
"""
Transform a genotypes matrix via the current haplotype

@@ -478,9 +478,9 @@ def transform(self, genotypes: GenotypesVCF) -> npt.NDArray[bool]:

Returns
-------
npt.NDArray[bool]
A 2D matrix of shape (num_samples, 2) where each entry in the matrix
denotes the presence of the haplotype in one chromosome of a sample
npt.NDArray
A 2D matrix of shape (num_samples, 2) where each entry in the matrix is a
bool denoting the presence of the haplotype in one chromosome of a sample
"""
var_IDs = self.varIDs
# ensure the variants in the Genotypes object are ordered according to var_IDs
@@ -1198,12 +1198,16 @@ def __iter__(
indexed = True
try:
haps_file = TabixFile(str(self.fname))
if region is not None:
haps_file.fetch(region=region, multiple_iterators=True)
except OSError:
indexed = False
# if the user requested a specific region or subset of haplotypes and the file
except ValueError:
indexed = False
# If the user requested a specific region or subset of haplotypes and the file
# is indexed, then we should handle it using tabix
# else, we use a regular text opener - b/c there's no benefit to using tabix
if region or (haplotypes and indexed):
if (region or haplotypes) and indexed:
haps_file = TabixFile(str(self.fname))
metas, extras = self.check_header(list(haps_file.header))
types = self._get_field_types(extras, metas.get("order"))
@@ -1232,8 +1236,8 @@ def __iter__(
)
haps_file.close()
else:
# the file is not indexed, so we can't assume it's sorted, either
# use hook_compressed to automatically handle gz files
# The file is not indexed, so we can't assume it's sorted, either
# Use hook_compressed to automatically handle gz files
with self.hook_compressed(self.fname, mode="r") as haps:
self.log.info("Not taking advantage of indexing.")
header_lines = []
1 change: 1 addition & 0 deletions haptools/data/tr_harmonizer.py
@@ -3,6 +3,7 @@
Handles VCFs generated by various TR genotyping tools
"""
import re
import math
import enum
import warnings
from typing import (
2 changes: 2 additions & 0 deletions haptools/transform.py
@@ -593,6 +593,8 @@ def transform_haps(
"that the IDs in your .hap file correspond with those you provided. "
f"Here are the first few missing haplotypes: {diff[:first_few]}"
)
if len(hp.data) == 0:
raise ValueError("Didn't load any haplotypes from the .hap file")

log.info("Extracting variants from haplotypes")
variants = {vr.id for id in hp.type_ids["H"] for vr in hp.data[id].variants}
6 changes: 3 additions & 3 deletions noxfile.py
@@ -3,9 +3,9 @@
import shutil
from pathlib import Path

import nox
from nox_poetry import Session
from nox_poetry import session
import nox # type: ignore
from nox_poetry import Session # type: ignore
from nox_poetry import session # type: ignore


package = "haptools"
1 change: 1 addition & 0 deletions tests/bench_transform.py
@@ -7,6 +7,7 @@
from datetime import datetime

import click
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

24 changes: 24 additions & 0 deletions tests/test_data.py
@@ -569,6 +569,20 @@ def test_write_genotypes(self):
fname.with_suffix(".pvar").unlink()
fname.unlink()

def test_write_genotypes_empty(self):
fname = DATADIR / "test_write.pgen"
gts = GenotypesPLINK(fname=fname)
gts.data = np.empty((0, 0, 0), dtype=np.uint8)
gts.samples = ()
gts.variants = np.empty(0, dtype=gts.variants.dtype)
gts.write()
gts.read()

# clean up afterwards: delete the files we created
fname.with_suffix(".psam").unlink()
fname.with_suffix(".pvar").unlink()
fname.unlink()

def test_write_genotypes_prephased(self):
gts = self._get_fake_genotypes_plink()

@@ -1851,6 +1865,16 @@ def test_write_ref_alt(self, multiallelic=False):

fname.unlink()

def test_write_empty(self):
fname = Path("test.vcf")
gts = GenotypesVCF(fname=fname)
gts.samples = ()
gts.variants = np.array([], dtype=gts.variants.dtype)
gts.data = np.empty((0, 0, 0), dtype=np.uint8)
gts.write()
gts.read()
fname.unlink()

def test_write_multiallelic(self):
self.test_write_ref_alt(multiallelic=True)

39 changes: 39 additions & 0 deletions tests/test_transform.py
@@ -371,3 +371,42 @@ def test_ancestry_from_bp(capfd):
captured = capfd.readouterr()
assert captured.out == ancestry_results
assert result.exit_code == 0


def test_transform_empty_hap(capfd):
gt_file = DATADIR / "simple.vcf.gz"
hp_file = Path("empty.hap")
hp_file_gz = Path("empty.hap.gz")
hp_file_idx = Path("empty.hap.gz.tbi")

# create an empty .hap file
with open(hp_file, "w") as f:
f.write("")

# can we run transform with the empty hap file?
cmd = f"transform --region 1:10116-10122 {gt_file} {hp_file}"
runner = CliRunner()
result = runner.invoke(main, cmd.split(" "))
captured = capfd.readouterr()
assert all(line for line in captured.out.split("\n") if line.startswith("#"))
assert result.exit_code != 0

# now, index the empty hap file and try again
cmd = f"index {hp_file}"
runner = CliRunner()
result = runner.invoke(main, cmd.split(" "), catch_exceptions=False)
captured = capfd.readouterr()
assert result.exit_code == 0
assert hp_file_gz.exists()
assert hp_file_idx.exists()

# what about now? does it still fail?
cmd = f"transform --region 1:10116-10122 {gt_file} {hp_file_gz}"
runner = CliRunner()
result = runner.invoke(main, cmd.split(" "))
captured = capfd.readouterr()
assert result.exit_code != 0

hp_file.unlink()
hp_file_gz.unlink()
hp_file_idx.unlink()