Skip to content

Commit

Permalink
done
Browse files Browse the repository at this point in the history
  • Loading branch information
marcelm committed Feb 1, 2024
1 parent ea62125 commit b25a123
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 99 deletions.
22 changes: 1 addition & 21 deletions src/cutadapt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
import itertools
import multiprocessing
from pathlib import Path
from typing import Tuple, Optional, Sequence, List, Any, Iterator, Union, Dict, Iterable
from typing import Tuple, Optional, Sequence, List, Any, Iterator, Union, Dict
from argparse import ArgumentParser, SUPPRESS, HelpFormatter

import dnaio
Expand Down Expand Up @@ -871,26 +871,6 @@ def make_filter(
outfiles=outfiles,
)
steps.append(step)
"""
def _create_demultiplexer(self, outfiles) -> Union[PairedDemultiplexer, CombinatorialDemultiplexer]:
def open_writer(file, file2):
return self._open_writer(file, file2, force_fasta=outfiles.force_fasta)
if outfiles.combinatorial_out is not None:
assert outfiles.untrimmed is None and outfiles.untrimmed2 is None
writers = dict()
for key, out in outfiles.combinatorial_out.items():
writers[key] = open_writer(out, outfiles.combinatorial_out2[key])
return CombinatorialDemultiplexer(writers)
else:
writers = dict()
if outfiles.untrimmed is not None:
writers[None] = open_writer(outfiles.untrimmed, outfiles.untrimmed2)
for name, file in outfiles.demultiplex_out.items():
writers[name] = open_writer(file, outfiles.demultiplex_out2[name])
return PairedDemultiplexer(writers)
"""

else:
# When adapters are being trimmed only in R1 or R2, override the pair filter mode
# as using the default of 'any' would regard all read pairs as untrimmed.
Expand Down
21 changes: 13 additions & 8 deletions src/cutadapt/files.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import errno
import io
import sys
from abc import ABC
from abc import ABC, abstractmethod
from enum import Enum
from typing import BinaryIO, Optional, Dict, Tuple, List, TextIO, Union
from typing import BinaryIO, Optional, Dict, List, TextIO, Any

import dnaio
from xopen import xopen
Expand Down Expand Up @@ -143,6 +143,7 @@ def open(self) -> InputFiles:


class ProxyWriter(ABC):
@abstractmethod
def drain(self) -> List[bytes]:
pass

Expand Down Expand Up @@ -212,7 +213,7 @@ def __init__(
self._writers: Dict = {}
self._proxy_files: List[ProxyWriter] = []
self._proxied = proxied
self._to_close = []
self._to_close: List[BinaryIO] = []
self._qualities = qualities
self._interleaved = interleaved

Expand All @@ -235,7 +236,9 @@ def open_text(self, path):
def open_record_writer(
self, *paths, interleaved: bool = False, force_fasta: bool = False
):
kwargs = dict(qualities=self._qualities, interleaved=interleaved)
kwargs: Dict[str, Any] = dict(
qualities=self._qualities, interleaved=interleaved
)
if len(paths) not in (1, 2):
raise ValueError("Expected one or two paths")
if interleaved and len(paths) != 1:
Expand Down Expand Up @@ -265,7 +268,9 @@ def open_record_writer_from_binary_io(
self, file: BinaryIO, interleaved: bool = False, force_fasta: bool = False
):
self._binary_files.append(file)
kwargs = dict(qualities=self._qualities, interleaved=interleaved)
kwargs: Dict[str, Any] = dict(
qualities=self._qualities, interleaved=interleaved
)
if force_fasta and file is sys.stdout.buffer:
kwargs["fileformat"] = "fasta"
if self._proxied:
Expand All @@ -290,9 +295,9 @@ def close(self) -> None:
f.close()
for f in self._writers.values():
f.close()
for f in self._binary_files:
if f is not sys.stdout.buffer:
f.close()
for bf in self._binary_files:
if bf is not sys.stdout.buffer:
bf.close()


class FileFormat(Enum):
Expand Down
28 changes: 3 additions & 25 deletions src/cutadapt/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
List,
Optional,
Any,
Tuple,
Dict,
Union,
TextIO,
BinaryIO,
)
from typing import List, Optional, Any, Tuple, Union, TextIO, BinaryIO

import dnaio

Expand All @@ -22,21 +13,8 @@
PairedEndModifierWrapper,
ModificationInfo,
)
from .predicates import (
DiscardUntrimmed,
Predicate,
DiscardTrimmed,
)
from .steps import (
SingleEndSink,
PairedEndSink,
SingleEndFilter,
PairedEndFilter,
Demultiplexer,
PairedDemultiplexer,
CombinatorialDemultiplexer,
SingleEndStep,
)
from .predicates import DiscardUntrimmed, Predicate
from .steps import PairedEndSink, PairedEndFilter, SingleEndStep

logger = logging.getLogger()

Expand Down
74 changes: 29 additions & 45 deletions src/cutadapt/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""
import itertools
from abc import ABC, abstractmethod
from typing import Tuple, Dict, Optional, Any, TextIO, Sequence, List
from typing import Tuple, Optional, Any, TextIO, Sequence, List

from dnaio import SequenceRecord

Expand Down Expand Up @@ -124,7 +124,6 @@ def __init__(
'both': The pair is discarded if both reads match.
'first': The pair is discarded if the first read matches.
"""
super().__init__()
if pair_filter_mode not in ("any", "both", "first"):
raise ValueError("pair_filter_mode must be 'any', 'both' or 'first'")
self._pair_filter_mode = pair_filter_mode
Expand Down Expand Up @@ -414,13 +413,13 @@ def __init__(
self,
adapter_names: Sequence[str],
template1: str,
template2: Optional[str],
template2: str,
untrimmed_output: Optional[str],
untrimmed_paired_output: Optional[str],
discard_untrimmed: bool,
outfiles: OutputFiles,
):
self._writers = self._open_writers(
self._writers, self._untrimmed_writer = self._open_writers(
adapter_names,
template1,
template2,
Expand All @@ -429,49 +428,39 @@ def __init__(
discard_untrimmed,
outfiles,
)
self._untrimmed_writer = self._writers.get(None, None)
self._statistics = ReadLengthStatistics()
self._filtered = 0

@staticmethod
def _open_writers(
adapter_names: Sequence[str],
template1: str,
template2: Optional[str],
template2: str,
untrimmed_output: Optional[str],
untrimmed_paired_output: Optional[str],
discard_untrimmed: bool,
outfiles: OutputFiles,
):
demultiplex_out = dict()
demultiplex_out2: Optional[Dict[str, Any]] = (
dict() if template2 is not None else None
)
for name in adapter_names:
path1 = template1.replace("{name}", name)
demultiplex_out[name] = file_opener.xopen(path1, "wb")
if demultiplex_out2 is not None:
assert template2 is not None
path2 = template2.replace("{name}", name)
demultiplex_out2[name] = file_opener.xopen(path2, "wb")
untrimmed_path: Optional[str] = template1.replace("{name}", "unknown")
if untrimmed_output:
untrimmed_path = untrimmed_output
path2 = template2.replace("{name}", name)
demultiplex_out[name] = outfiles.open_record_writer(path1, path2)

if discard_untrimmed:
untrimmed = None
else:
untrimmed = file_opener.xopen(untrimmed_path, "wb")
if template2 is not None:
untrimmed2_path = template2.replace("{name}", "unknown")
if untrimmed_paired_output:
untrimmed2_path = untrimmed_paired_output
if discard_untrimmed:
untrimmed2 = None
if untrimmed_output is not None:
untrimmed_path1 = untrimmed_output
else:
untrimmed2 = file_opener.xopen(untrimmed2_path, "wb")
else:
untrimmed2 = None
return demultiplex_out, demultiplex_out2, untrimmed, untrimmed2
untrimmed_path1 = template1.replace("{name}", "unknown")
if untrimmed_paired_output is not None:
untrimmed_path2 = untrimmed_paired_output
else:
untrimmed_path2 = template2.replace("{name}", "unknown")
untrimmed = outfiles.open_record_writer(untrimmed_path1, untrimmed_path2)

return demultiplex_out, untrimmed

def __call__(
self, read1, read2, info1: ModificationInfo, info2: ModificationInfo
Expand Down Expand Up @@ -509,8 +498,8 @@ def __init__(
self,
adapter_names,
adapter_names2,
output_template: str,
paired_output_template: str,
template1: str,
template2: str,
discard_untrimmed: bool,
outfiles: OutputFiles,
):
Expand All @@ -523,8 +512,8 @@ def __init__(
self._writers = self._open_writers(
adapter_names,
adapter_names2,
output_template,
paired_output_template,
template1,
template2,
discard_untrimmed,
outfiles,
)
Expand All @@ -534,13 +523,12 @@ def __init__(
def _open_writers(
adapter_names: Sequence[str],
adapter_names2: Sequence[str],
output_template: str,
paired_output_template: str,
template1: str,
template2: str,
discard_untrimmed: bool,
outfiles: OutputFiles,
):
combinatorial_out = dict()
combinatorial_out2 = dict()
writers = dict()
extra: List[Tuple[Optional[str], Optional[str]]]
if discard_untrimmed:
extra = []
Expand All @@ -553,15 +541,11 @@ def _open_writers(
): # type: ignore
fname1 = name1 if name1 is not None else "unknown"
fname2 = name2 if name2 is not None else "unknown"
path1 = output_template.replace("{name1}", fname1).replace(
"{name2}", fname2
)
path2 = paired_output_template.replace("{name1}", fname1).replace(
"{name2}", fname2
)
combinatorial_out[(name1, name2)] = file_opener.xopen(path1, "wb")
combinatorial_out2[(name1, name2)] = file_opener.xopen(path2, "wb")
return combinatorial_out, combinatorial_out2
path1 = template1.replace("{name1}", fname1).replace("{name2}", fname2)
path2 = template2.replace("{name1}", fname1).replace("{name2}", fname2)
writers[(name1, name2)] = outfiles.open_record_writer(path1, path2)

return writers

def __call__(self, read1, read2, info1, info2) -> Optional[RecordPair]:
"""
Expand Down

0 comments on commit b25a123

Please sign in to comment.