Skip to content

Commit

Permalink
Introduce ProxyTextFile
Browse files Browse the repository at this point in the history
Steps that write to output files cannot just write to normal file-like
objects because we need to switch to BytesIO-backed files when running in
parallel. The ProxyTextFile acts like a normal file, but is backed by a
BytesIO object.

The plan is to make it possible to create a Pipeline just once and not
have to recreate it in each worker.
  • Loading branch information
marcelm committed Jan 26, 2024
1 parent 1fff8cd commit b6e256c
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 54 deletions.
37 changes: 29 additions & 8 deletions src/cutadapt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
from cutadapt.pipeline import SingleEndPipeline, PairedEndPipeline
from cutadapt.runners import Pipeline, run_pipeline
from cutadapt.files import InputPaths, OutputFiles, FileOpener, OutputPaths
from cutadapt.steps import InfoFileWriter, PairedSingleEndStep
from cutadapt.utils import available_cpu_count, Progress, DummyProgress
from cutadapt.log import setup_logging, REPORT
from cutadapt.qualtrim import HasNoQualities
Expand Down Expand Up @@ -444,6 +445,7 @@ def open_output_files(
file_opener: FileOpener,
adapter_names: Sequence[Optional[str]],
adapter_names2: Sequence[Optional[str]],
proxied: bool,
) -> OutputFiles:
"""
Return an OutputFiles instance. If demultiplex is True, the untrimmed, untrimmed2, out and out2
Expand All @@ -465,8 +467,6 @@ def open_output_files(
]
)
paths = OutputPaths(opener=file_opener.xopen)
if args.info_file:
paths.register(args.info_file)
if args.rest_file:
paths.register(args.rest_file)
if args.wildcard_file:
Expand Down Expand Up @@ -551,8 +551,9 @@ def open_output_files(

outputs = paths.open()
return OutputFiles(
file_opener=file_opener,
proxied=proxied,
rest=outputs.get(args.rest_file),
info=outputs.get(args.info_file),
wildcard=outputs.get(args.wildcard_file),
too_short=too_short,
too_short2=too_short2,
Expand Down Expand Up @@ -811,7 +812,7 @@ def __init__(self, args, paired, adapters, adapters2):
self.adapters = adapters
self.adapters2 = adapters2

def make(self) -> Pipeline: # noqa: C901
def make(self, steps) -> Pipeline: # noqa: C901
"""
Set up a processing pipeline from parsed command-line arguments.
Expand Down Expand Up @@ -887,9 +888,11 @@ def make(self) -> Pipeline: # noqa: C901
pair_filter_mode = (
"any" if self.args.pair_filter is None else self.args.pair_filter
)
pipeline = PairedEndPipeline(modifiers, pair_filter_mode) # type: Any
pipeline = PairedEndPipeline(
modifiers, pair_filter_mode, steps
) # type: Any
else:
pipeline = SingleEndPipeline(modifiers)
pipeline = SingleEndPipeline(modifiers, steps)

# When adapters are being trimmed only in R1 or R2, override the pair filter mode
# as using the default of 'any' would regard all read pairs as untrimmed.
Expand Down Expand Up @@ -1134,12 +1137,24 @@ def main(cmdlineargs, default_outfile=sys.stdout.buffer) -> Statistics:
adapters, adapters2 = adapters_from_args(args)
log_adapters(adapters, adapters2 if paired else None)

pipeline = PipelineMaker(args, paired, adapters, adapters2).make()
adapter_names: List[Optional[str]] = [a.name for a in adapters]
adapter_names2: List[Optional[str]] = [a.name for a in adapters2]
outfiles = open_output_files(
args, default_outfile, file_opener, adapter_names, adapter_names2
args,
default_outfile,
file_opener,
adapter_names,
adapter_names2,
proxied=cores > 1,
)
steps = []
if args.info_file:
step: Any = InfoFileWriter(outfiles.open_text(args.info_file))
if paired:
step = PairedSingleEndStep(step)
steps.append(step)
pipeline = PipelineMaker(args, paired, adapters, adapters2).make(steps)

logger.info(
"Processing %s reads on %d core%s ...",
{False: "single-end", True: "paired-end"}[pipeline.paired],
Expand Down Expand Up @@ -1169,6 +1184,12 @@ def main(cmdlineargs, default_outfile=sys.stdout.buffer) -> Statistics:
logger.error("%s", e)
exit_code = 2 if isinstance(e, CommandLineError) else 1
sys.exit(exit_code)
finally:
# TODO ...
try:
outfiles.close()
except UnboundLocalError:
pass

elapsed = time.time() - start_time
if args.report == "minimal":
Expand Down
78 changes: 71 additions & 7 deletions src/cutadapt/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import errno
import io
import sys
from typing import BinaryIO, Optional, Dict, Tuple, List, Callable
from typing import BinaryIO, Optional, Dict, Tuple, List, Callable, TextIO

import dnaio
from xopen import xopen
Expand Down Expand Up @@ -141,6 +141,29 @@ def open(self) -> InputFiles:
return InputFiles(*files, interleaved=self.interleaved)


class ProxyTextFile:
    """
    A write-only, text-mode file-like object backed by an in-memory BytesIO.

    Text passed to write() is encoded and collected in the buffer. Calling
    drain() returns the bytes collected so far and empties the buffer, so the
    object can be reused.

    Instances can be pickled, but pickling does not preserve any pending
    (not yet drained) data: the unpickled object starts out empty.
    """

    def __init__(self):
        self._buffer = io.BytesIO()
        # Pin the encoding. Without it, TextIOWrapper uses the locale's
        # preferred encoding, so the bytes returned by drain() would depend on
        # the environment of the process doing the writing.
        # NOTE(review): confirm UTF-8 matches what the non-proxied text files
        # opened elsewhere use.
        self._file = io.TextIOWrapper(self._buffer, encoding="utf-8")

    def write(self, text: str) -> int:
        """Write *text* to the buffer and return the number of characters written."""
        return self._file.write(text)

    def drain(self) -> bytes:
        """Return all bytes written since the previous drain() and empty the buffer."""
        # Push text still held by the TextIOWrapper into the BytesIO first
        self._file.flush()
        b = self._buffer.getvalue()
        # Reset the buffer so that subsequent writes start from scratch
        self._buffer.seek(0)
        self._buffer.truncate()
        return b

    def __getstate__(self):
        """TextIOWrapper cannot be pickled. Just don’t include our state."""
        return True  # ensure __setstate__ is called

    def __setstate__(self, state):
        # The pickled state carries no data (see __getstate__), so just
        # re-create a fresh, empty buffer and wrapper.
        self.__init__()


class OutputFiles:
"""
The attributes are either None or open file-like objects except for demultiplex_out
Expand All @@ -150,6 +173,9 @@ class OutputFiles:

def __init__(
self,
*,
file_opener: FileOpener,
proxied: bool,
out: Optional[BinaryIO] = None,
out2: Optional[BinaryIO] = None,
untrimmed: Optional[BinaryIO] = None,
Expand All @@ -158,7 +184,6 @@ def __init__(
too_short2: Optional[BinaryIO] = None,
too_long: Optional[BinaryIO] = None,
too_long2: Optional[BinaryIO] = None,
info: Optional[BinaryIO] = None,
rest: Optional[BinaryIO] = None,
wildcard: Optional[BinaryIO] = None,
demultiplex_out: Optional[Dict[str, BinaryIO]] = None,
Expand All @@ -167,6 +192,14 @@ def __init__(
combinatorial_out2: Optional[Dict[Tuple[str, str], BinaryIO]] = None,
force_fasta: Optional[bool] = None,
):
self._file_opener = file_opener
# TODO do these actually have to be dicts?
self._binary_files: Dict[str, BinaryIO] = {}
self._text_files: Dict[str, TextIO] = {}
self._proxy_files: Dict[str, ProxyTextFile] = {}
self._proxied = proxied
self.force_fasta = force_fasta

self.out = out
self.out2 = out2
self.untrimmed = untrimmed
Expand All @@ -175,14 +208,36 @@ def __init__(
self.too_short2 = too_short2
self.too_long = too_long
self.too_long2 = too_long2
self.info = info
self.rest = rest
self.wildcard = wildcard
self.demultiplex_out = demultiplex_out
self.demultiplex_out2 = demultiplex_out2
self.combinatorial_out = combinatorial_out
self.combinatorial_out2 = combinatorial_out2
self.force_fasta = force_fasta

def open_text(self, path):
    """
    Open the output file at *path* and return an object to write text to.

    If this instance is proxied, the actual file is opened in binary mode and
    registered (see binary_files()), and a new ProxyTextFile is registered
    (see proxy_files()) and returned. Otherwise, the file is opened in text
    mode, registered and returned directly.

    Raises ValueError if *path* has already been opened through this method.
    """
    if path in self._binary_files or path in self._text_files:
        # A plain string cannot be raised in Python 3 (that is a TypeError),
        # so use a proper exception. Both registries are checked so that the
        # duplicate is detected in proxied and non-proxied mode alike.
        raise ValueError(f"Output path {path!r} has already been opened")
    # TODO
    # - serial runner needs only text_file
    # - parallel runner needs binary_file and proxy_file
    # split into SerialOutputFiles and ParallelOutputFiles?
    if self._proxied:
        binary_file = self._file_opener.xopen(path, "wb")
        self._binary_files[path] = binary_file
        proxy_file = ProxyTextFile()
        self._proxy_files[path] = proxy_file
        return proxy_file
    else:
        text_file = self._file_opener.xopen(path, "wt")
        self._text_files[path] = text_file
        return text_file

def binary_files(self):
    """Return a new list of the binary files registered by open_text()."""
    return [*self._binary_files.values()]

def proxy_files(self) -> List[ProxyTextFile]:
    """Return a new list of the ProxyTextFile objects registered by open_text()."""
    return [*self._proxy_files.values()]

def __iter__(self):
for f in [
Expand All @@ -194,7 +249,6 @@ def __iter__(self):
self.too_short2,
self.too_long,
self.too_long2,
self.info,
self.rest,
self.wildcard,
]:
Expand All @@ -210,12 +264,17 @@ def __iter__(self):
for f in outs.values():
assert f is not None
yield f
yield from self._binary_files.values()

def as_bytesio(self) -> "OutputFiles":
"""
Create a new OutputFiles instance that has BytesIO instances for each non-None output file
"""
result = OutputFiles(force_fasta=self.force_fasta)
result = OutputFiles(
file_opener=self._file_opener,
proxied=False,
force_fasta=self.force_fasta,
)
for attr in (
"out",
"out2",
Expand All @@ -225,7 +284,6 @@ def as_bytesio(self) -> "OutputFiles":
"too_short2",
"too_long",
"too_long2",
"info",
"rest",
"wildcard",
):
Expand All @@ -249,6 +307,12 @@ def close(self) -> None:
if f is sys.stdout or f is sys.stdout.buffer:
continue
f.close()
if self._proxied:
for f in self._binary_files.values():
f.close()
else:
for f in self._text_files.values():
f.close()


class OpenedOutputs:
Expand Down
Loading

0 comments on commit b6e256c

Please sign in to comment.