diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 739ab296..5e70337a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: args: [--safe, --line-length=100, --preview] - id: black-jupyter args: [--safe, --line-length=100, --preview] - language_version: python3.9 + language_version: python3 - repo: https://github.com/pycqa/docformatter rev: v1.7.5 diff --git a/strax/__init__.py b/strax/__init__.py index ba466a61..18955168 100644 --- a/strax/__init__.py +++ b/strax/__init__.py @@ -20,6 +20,7 @@ from .mailbox import * from .processor import * +from .processors import * from .context import * from .run_selection import * from .corrections import * diff --git a/strax/chunk.py b/strax/chunk.py index 0cc79810..df5275a1 100644 --- a/strax/chunk.py +++ b/strax/chunk.py @@ -420,3 +420,44 @@ def _update_subruns_in_chunk(chunks): else: subruns[subrun_id] = subrun_start_end return subruns + + +@export +class Rechunker: + """Helper class for rechunking. + + Send in chunks via receive, which returns either None (no chunk to send) or a chunk to send. + + Don't forget a final call to .flush() to get any final data out! + + """ + + def __init__(self, rechunk=False, run_id=None): + self.rechunk = rechunk + self.is_superrun = run_id and run_id.startswith("_") and not run_id.startswith("__") + self.run_id = run_id + + self.cache = None + + def receive(self, chunk): + if self.is_superrun: + chunk = strax.transform_chunk_to_superrun_chunk(self.run_id, chunk) + if not self.rechunk: + # We aren't rechunking + return chunk + if self.cache: + # We have an old chunk, so we need to concatenate + chunk = strax.Chunk.concatenate([self.cache, chunk]) + if chunk.data.nbytes >= chunk.target_size_mb * 1e6: + # Enough data to send a new chunk! + self.cache = None + return chunk + else: + # Not enough data yet, so we cache the chunk + self.cache = chunk + return None + + def flush(self): + result = self.cache + self.cache = None + return result diff --git a/strax/context.py b/strax/context.py index 28dfb3c7..193c45c9 100644 --- a/strax/context.py +++ b/strax/context.py @@ -191,7 +191,11 @@ class Context: _run_defaults_cache: dict storage: ty.List[strax.StorageFrontend] - def __init__(self, storage=None, config=None, register=None, register_all=None, **kwargs): + processors: ty.Mapping[str, strax.BaseProcessor] + + def __init__( + self, storage=None, config=None, register=None, register_all=None, processors=None, **kwargs + ): """Create a strax context. :param storage: Storage front-ends to use. Can be: @@ -202,7 +206,9 @@ def __init__(self, storage=None, config=None, register=None, register_all=None, applied to plugins :param register: plugin class or list of plugin classes to register :param register_all: module for which all plugin classes defined in it - will be registered. + will be registered. + :param processors: A mapping of processor names to classes to use for + data processing. Any additional kwargs are considered Context-specific options; see Context.takes_config. 
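Editor's note: the new `Rechunker` factors the rechunking and superrun-transformation logic out of `save_from` so it can be reused by the single-thread saver spy introduced later in this diff: `receive()` returns `None` while data is still being cached (or the chunk itself when no rechunking is requested), and a final `flush()` returns whatever remains in the cache. A minimal sketch of that contract; `drive_saver`, `source` and `saver` are illustrative stand-ins, not part of this change:

```python
import strax

def drive_saver(source, saver, rechunk=True, run_id="0"):
    """Sketch: push chunks from `source` through a Rechunker into `saver`."""
    rechunker = strax.Rechunker(rechunk=rechunk, run_id=run_id)
    chunk_i = 0
    for chunk in source:
        out = rechunker.receive(chunk)   # None while the data is still being cached
        if out is not None:
            saver.save(chunk=out, chunk_i=chunk_i)
            chunk_i += 1
    leftover = rechunker.flush()         # always flush: returns any cached remainder
    if leftover is not None:
        saver.save(chunk=leftover, chunk_i=chunk_i)
```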
@@ -226,12 +232,32 @@ def __init__(self, storage=None, config=None, register=None, register_all=None, if register is not None: self.register(register) + if processors is None: + processors = strax.PROCESSORS + + if isinstance(processors, str): + processors = [processors] + + if isinstance(processors, (list, tuple)): + ps = {} + for processor in processors: + if isinstance(processor, str) and processor in strax.PROCESSORS: + ps[processor] = strax.PROCESSORS[processor] + elif isinstance(processor, strax.BaseProcessor): + ps[processor.__name__] = processor + else: + raise ValueError(f"Unknown processor {processor}") + processors = ps + + self.processors = processors + def new_context( self, storage=tuple(), config=None, register=None, register_all=None, + processors=None, replace=False, **kwargs, ): @@ -255,7 +281,7 @@ def new_context( config = strax.combine_configs(self.config, config, mode="update") kwargs = strax.combine_configs(self.context_config, kwargs, mode="update") - new_c = Context(storage=storage, config=config, **kwargs) + new_c = Context(storage=storage, config=config, processors=processors, **kwargs) if not replace: new_c._plugin_class_registry = self._plugin_class_registry.copy() new_c.register_all(register_all) @@ -1434,7 +1460,7 @@ def to_absolute_time_range( def get_iter( self, run_id: str, - targets: ty.Union[ty.Tuple[str], ty.List[str]], + targets, save=tuple(), max_workers=None, time_range=None, @@ -1449,6 +1475,7 @@ def get_iter( progress_bar=True, multi_run_progress_bar=True, _chunk_number=None, + processor=None, **kwargs, ) -> ty.Iterator[strax.Chunk]: """Compute target for run_id and iterate over results. @@ -1516,8 +1543,17 @@ def get_iter( if k.startswith("_temp"): del self._plugin_class_registry[k] + if processor is None: + processor = list(self.processors)[0] + + if isinstance(processor, str): + processor = self.processors[processor] + + if not hasattr(processor, "iter"): + raise ValueError("Processors must implement a iter methed.") + seen_a_chunk = False - generator = strax.ThreadedMailboxProcessor( + generator = processor( components, max_workers=max_workers, allow_shm=self.context_config["allow_shm"], @@ -2542,8 +2578,7 @@ def add_method(cls, f): :param multi_run_progress_bar: Display a progress bar for loading multiple runs """ -get_docs = ( - """ +get_docs = """ :param run_id: run id to get :param targets: list/tuple of strings of data type names to get :param ignore_errors: Return the data for the runs that successfully loaded, even if some runs @@ -2563,9 +2598,10 @@ def add_method(cls, f): :param run_id_as_bytes: Boolean if true uses byte string instead of an unicode string added to a multi-run array. This can save a lot of memory when loading many runs. +:param processor: Name of the processor to use. If not specified, the + first processor from the context's processor list is used. 
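Editor's note: with the processor registry wired into `Context.__init__` and `get_iter` above, a context can be limited to specific processing backends, and a backend can be chosen per call by its registry name. A small sketch, assuming the `Records`/`Peaks` plugins from `strax.testutils` and that the `get_*` convenience methods forward the `processor` keyword on to `get_iter`, as the shared docstring above suggests:

```python
import strax
from strax.testutils import Records, Peaks

st = strax.Context(
    storage=[],
    register=[Records, Peaks],
    # Entries used here are keys of strax.PROCESSORS; the first becomes the default.
    processors=["threaded_mailbox", "single_thread"],
    use_per_run_defaults=True,
)

# Uses the default (first) processor:
peaks = st.get_array(run_id="0", targets="peaks")

# Explicit per-call choice by registry name:
peaks_st = st.get_array(run_id="0", targets="peaks", processor="single_thread")
```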
""" - + select_docs -) +get_docs += select_docs for attr in dir(Context): attr_val = getattr(Context, attr) diff --git a/strax/processor.py b/strax/processor.py index fb53e427..5c752d34 100644 --- a/strax/processor.py +++ b/strax/processor.py @@ -1,326 +1,3 @@ -from concurrent import futures -from functools import partial -import logging -import typing as ty -import os -import sys -from concurrent.futures import ProcessPoolExecutor - -import numpy as np - -import strax - -export, __all__ = strax.exporter() - -try: - import npshmex - - SHMExecutor = npshmex.ProcessPoolExecutor - npshmex.register_array_wrapper(strax.Chunk, "data") -except ImportError: - # This is allowed to fail, it only crashes if allow_shm = True - SHMExecutor = None - - -@export -class ProcessorComponents(ty.NamedTuple): - """Specification to assemble a processor.""" - - plugins: ty.Dict[str, strax.Plugin] - loaders: ty.Dict[str, ty.Callable] - loader_plugins: ty.Dict[str, strax.Plugin] # Required for inline ParallelSource plugin. - savers: ty.Dict[str, ty.List[strax.Saver]] - targets: ty.Tuple[str] - - -class MailboxDict(dict): - def __init__(self, *args, lazy=False, **kwargs): - super().__init__(*args, **kwargs) - self.lazy = lazy - - def __missing__(self, key): - res = self[key] = strax.Mailbox(name=key + "_mailbox", lazy=self.lazy) - return res - - -@export -class ThreadedMailboxProcessor: - mailboxes: ty.Dict[str, strax.Mailbox] - - def __init__( - self, - components: ProcessorComponents, - allow_rechunk=True, - allow_shm=False, - allow_multiprocess=False, - allow_lazy=True, - max_workers=None, - max_messages=4, - timeout=60, - is_superrun=False, - ): - self.log = logging.getLogger(self.__class__.__name__) - self.components = components - - self.log.debug("Processor components are: " + str(components)) - - if allow_multiprocess and os.name == "nt": - print("You're on Windows! Multiprocessing disabled, here be dragons.") - allow_multiprocess = False - - if max_workers in [None, 1]: - # Disable the executors: work in one process. - # Each plugin works completely in its own thread. - self.process_executor = self.thread_executor = None - lazy = allow_lazy - else: - lazy = False - # Use executors for parallelization of computations. - self.thread_executor = futures.ThreadPoolExecutor(max_workers=max_workers) - - mp_plugins = {d: p for d, p in components.plugins.items() if p.parallel == "process"} - if allow_multiprocess and len(mp_plugins): - _proc_ex = ProcessPoolExecutor - if allow_shm: - if SHMExecutor is None: - raise RuntimeError( - "You must install npshmex to enable shm transfer of numpy arrays." - ) - _proc_ex = SHMExecutor - self.process_executor = _proc_ex(max_workers=max_workers) - - # Combine as many plugins /savers as possible in one process - # TODO: more intelligent start determination, multiple starts - start_from = list(mp_plugins.keys())[ - int(np.argmin([len(p.depends_on) for p in mp_plugins.values()])) - ] - components = strax.ParallelSourcePlugin.inline_plugins( - components, start_from, log=self.log - ) - self.components = components - self.log.debug("Altered components for multiprocessing: " + str(components)) - else: - self.process_executor = self.thread_executor # type: ignore - - # Figure which outputs - # - we should exclude from the flow control in lazy mode, - # because they are produced but not required. 
- # - we should discard (produced but neither required not saved) - produced = set(components.loaders) - required = set(components.targets) - # Do not just take keys from savers, perhaps some keys - # have no savers are under them (see #444) - saved = set([k for k, v in components.savers.items() if v]) - - for p in components.plugins.values(): - produced.update(p.provides) - required.update(p.depends_on) - to_flow_freely = produced - required - to_discard = to_flow_freely - saved - self.log.debug( - f"to_flow_freely {to_flow_freely}" - f"to_discard {to_discard}" - f"produced {produced}" - f"required {required}" - f"saved {saved}" - ) - - self.mailboxes = MailboxDict(lazy=lazy) - - for d, loader in components.loaders.items(): - assert d not in components.plugins - # If paralellizing, use threads for loading - # the decompressor releases the gil, and we have a lot - # of data transfer to do - self.mailboxes[d].add_sender(loader(executor=self.thread_executor), name=f"load:{d}") - - multi_output_seen: ty.List[strax.Plugin] = [] - for d, p in components.plugins.items(): - if p in multi_output_seen: - continue - - if p.__class__ in [mp_seen.__class__ for mp_seen in multi_output_seen]: - raise ValueError( - "A multi-output plugin is registered with different " - "instances for its provided data_types!" - ) - - executor = None - if p.parallel == "process": - executor = self.process_executor - elif p.parallel: - executor = self.thread_executor # type: ignore - - if p.multi_output: - multi_output_seen.append(p) - - # Create temp mailbox that receives multi-output dicts - # and sends them forth to other mailboxes - mname = p.__class__.__name__ + "_divide_outputs" - self.mailboxes[mname].add_sender( - p.iter( - iters={dep: self.mailboxes[dep].subscribe() for dep in p.depends_on}, - executor=executor, - ), - name=f"divide_outputs:{d}", - ) - - # If we have a plugin with double dependency both outputs - # of a multioutput-plugin are required. 
Hence flow-freely - # is empty an needs to be updated here: - provided_data_types = set(p.provides) - reader_data_types = set(strax.to_str_tuple(d)) - double_dependency = provided_data_types - reader_data_types - to_flow_freely |= double_dependency - self.log.debug(f"Updating flow freely for {mname} to be {to_flow_freely}") - - self.mailboxes[mname].add_reader( - partial( - strax.divide_outputs, - lazy=lazy, - # make sure to subscribe the outputs of the mp_plugins - mailboxes={k: self.mailboxes[k] for k in p.provides}, - flow_freely=to_flow_freely, - outputs=p.provides, - ) - ) - - else: - self.mailboxes[d].add_sender( - p.iter( - iters={dep: self.mailboxes[dep].subscribe() for dep in p.depends_on}, - executor=executor, - ), - name=f"build:{d}", - ) - - dtypes_built = {d: p for p in components.plugins.values() for d in p.provides} - for d, savers in components.savers.items(): - for s_i, saver in enumerate(savers): - if d in dtypes_built: - can_drive = not lazy - rechunk = dtypes_built[d].can_rechunk(d) and allow_rechunk - else: - # This is storage conversion mode - # TODO: Don't know how to get this info, for now, - # be conservative and don't rechunk - can_drive = True - rechunk = is_superrun and allow_rechunk - - self.mailboxes[d].add_reader( - partial( - saver.save_from, - rechunk=rechunk, - # If paralellizing, use threads for saving - # the compressor releases the gil, - # and we have a lot of data transfer to do - executor=self.thread_executor, - ), - can_drive=can_drive, - name=f"save_{s_i}:{d}", - ) - - # For multi-output plugins, an output may be neither saved nor - # required, and thus has to be discarded. - # This should happen rarely in production (when you actually - # care about the data, you will be saving it) - def discarder(source): - for _ in source: - pass - - for d in to_discard: - self.mailboxes[d].add_reader(discarder, name=f"discard_{d}") - - # Set to preferred number of maximum messages - # TODO: may not work if plugins are inlined?? - for d, m in self.mailboxes.items(): - m.max_messages = max_messages - m.timeout = timeout - if d in components.plugins: - max_m = components.plugins[d].max_messages - if max_m is not None: - m.max_messages = max_m - - # Remove defaultdict-like behaviour; all mailboxes should - # have been made by now. See #444 - self.mailboxes = dict(self.mailboxes) - self.log.debug( - f"Created the following mailboxes: {self.mailboxes} with the " - f"following threads: {[(d, m._threads) for d, m in self.mailboxes.items()]}" - ) - - def iter(self): - target = self.components.targets[0] - final_generator = self.mailboxes[target].subscribe() - - self.log.debug("Starting threads") - for m in self.mailboxes.values(): - self.log.debug(f"start {m}") - m.start() - - self.log.debug(f"Yielding {target}") - traceback, exc, reason = None, None, None - - try: - yield from final_generator - - # GeneratorExit results from exception in caller - # (on garbage collection, .close() is called, see PEP342) - except (Exception, GeneratorExit) as e: - self.log.fatal(f"Target Mailbox ({target}) killed, exception {type(e)}, message {e}") - if isinstance(e, strax.MailboxKilled): - _, exc, traceback = reason = e.args[0] - else: - exc = e - reason = (e.__class__, e, sys.exc_info()[2]) - traceback = reason[2] - - # We will reraise it in just a moment... - - if exc is not None: - if isinstance(exc, GeneratorExit): - print("Main generator exited irregularly?!") - reason[2] = ( - "Hm, interesting. 
Most likely an exception was thrown " - "outside strax, but we did not handle it properly." - ) - - # Kill the mailboxes - for m in self.mailboxes.values(): - if m != target: - self.log.debug(f"Killing {m}") - m.kill(upstream=True, reason=reason) - - self.log.debug("Closing threads") - for m in self.mailboxes.values(): - m.cleanup() - self.log.debug("Closing threads completed") - - self.log.debug("Closing executors") - if self.thread_executor is not None: - self.thread_executor.shutdown(wait=True) - if self.process_executor not in [None, self.thread_executor]: - self.process_executor.shutdown(wait=True) - self.log.debug("Closing executors completed") - - if exc is not None: - # Reraise exception. This is outside the except block - # to avoid the 'during handling of this exception, another - # exception occurred' stuff from confusing the traceback - # which is printed for the user - self.log.debug("Reraising exception") - raise exc.with_traceback(traceback) - - # Check the savers for any exception that occurred during saving - # These are thrown back to the mailbox, but if that has already closed - # it doesn't trigger a crash... - # TODO: add savers inlined by parallelsourceplugin - # TODO: need to look at plugins too if we ever implement true - # multi-target mode - for k, saver_list in self.components.savers.items(): - for s in saver_list: - if s.got_exception: - self.log.fatal(f"Caught error while saving {k}!") - raise s.got_exception - - self.log.debug("Processing finished") +# flake8: noqa +# Legacy import, used in a single place in straxen. +from .processors.threaded_mailbox import SHMExecutor diff --git a/strax/processors/__init__.py b/strax/processors/__init__.py new file mode 100644 index 00000000..346fa5d3 --- /dev/null +++ b/strax/processors/__init__.py @@ -0,0 +1,14 @@ +from .base import * +from .threaded_mailbox import * +from .single_thread import * + +# This is redundant with the star-imports above, but some flake8 +# versions require this +from .threaded_mailbox import ThreadedMailboxProcessor +from .single_thread import SingleThreadProcessor + +PROCESSORS = { + "default": ThreadedMailboxProcessor, + "threaded_mailbox": ThreadedMailboxProcessor, + "single_thread": SingleThreadProcessor, +} diff --git a/strax/processors/base.py b/strax/processors/base.py new file mode 100644 index 00000000..4f747ab4 --- /dev/null +++ b/strax/processors/base.py @@ -0,0 +1,30 @@ +import logging +import typing as ty + +import strax + +export, __all__ = strax.exporter() + + +@export +class ProcessorComponents(ty.NamedTuple): + """Specification to assemble a processor.""" + + plugins: ty.Dict[str, strax.Plugin] + loaders: ty.Dict[str, ty.Callable] + # Required for inline ParallelSource plugin. 
+ loader_plugins: ty.Dict[str, strax.Plugin] + savers: ty.Dict[str, ty.List[strax.Saver]] + targets: ty.Tuple[str] + + +@export +class BaseProcessor: + components: ProcessorComponents + + def __init__(self, components: ProcessorComponents, **kwargs): + self.log = logging.getLogger(self.__class__.__name__) + self.components = components + + def iter(self): + raise NotImplementedError diff --git a/strax/processors/post_office.py b/strax/processors/post_office.py new file mode 100644 index 00000000..dc4c8318 --- /dev/null +++ b/strax/processors/post_office.py @@ -0,0 +1,298 @@ +"""Single-threaded message bus / mailbox-system replacement code.""" + +import logging +import time +import typing as ty + +log = logging.getLogger("strax.post_office") + + +class Spy: + """Template for spies; a spy that does nothing.""" + + def receive(self, msg): + """Called when a new message is produced.""" + pass + + def close(self): + """Called when the topic is exhausted.""" + pass + + def kill(self, reason): + """Called when closing the spy prematurely, e.g. during exception handling.""" + self.close() + + +class PostOffice: + """A single-threaded message bus that uses iterators. + + This allows you to register producers and create readers of messages + of different topics. You can also register 'spies', which will get each + message just after it has been produced. + + Notes: + * The readers are iterators (technically generators), and the producers + should be iterable as well (probably you implement them as generators). + * Only one producer can be registered for each topic. + * Producers may produce (topic -> message) dicts and thereby feed multiple + topics at once. + * If multiple readers are registered for the same topic, PostOffice + will save messages not yet read by all readers. + * If you create a reader and never iterate it to completion, messages + will be saved until the PostOffice is garbage collected. + * We only call .close() on a spy when the topic is exhausted. To close + all spies prematurely (e.g. to handle an exception), call + .kill_spies(). + + """ + + # Time tracking + + #: Currently active topic/code + active_topic: str = "" + #: Dict of actiactive code -> seconds spent in that code + time_spent: ty.Dict[ty.Tuple[str], float] + + # Internal state + + #: Set of topics that have been exhausted (no more messages will come) + _exhausted_topics: ty.Set[str] + + #: Set of topics that are multi output + #: (i.e. the producer makes topic -> message dicts) + _multi_output_topics: ty.Set[str] + + # Dict: topic -> list with (msg_number, msg) + _saved_mail: ty.Dict[str, ty.List[ty.Tuple[int, ty.Any]]] + # Dict: topic -> list of spies + _spies: ty.Dict[str, ty.List[Spy]] + # Dict: topic-> iterator that produces messages + _producers: ty.Dict[str, ty.Iterable] + # Dict: topic -> last message produced + _last_msg_produced: ty.Dict[str, int] + # Dict: topic -> reader_name -> last message number recieved + _last_msg_read: ty.Dict[str, ty.Dict[str, int]] + # Dict: topic -> list of readers that are done + _readers_done: ty.Dict[str, ty.List[str]] + + # (Factoring the above variables into a Topic class didn't work for me. + # Multi-output producers exist, topic would be != topic_name, etc..) 
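Editor's note: `PostOffice` (continued below) replaces the threaded mailbox system with a pull-driven, single-threaded message bus: producers are plain iterators registered per topic, readers are generators handed out by `get_iter`, and spies (savers, in practice) receive every message as it is produced and are closed when the topic is exhausted. A toy sketch of that interaction; `PrintSpy` and the topic/reader names are illustrative only:

```python
from strax.processors.post_office import PostOffice, Spy

class PrintSpy(Spy):
    """Toy spy standing in for a saver: print every message it receives."""

    def receive(self, msg):
        print("spy saw:", msg)

    def close(self):
        print("topic exhausted")

po = PostOffice()
# One producer per topic; any iterator works.
po.register_producer(iter(range(3)), topic="numbers")
po.register_spy(PrintSpy(), topic="numbers")

# Production is lazy: messages are pulled by the reader's generator.
for msg in po.get_iter(topic="numbers", reader="consumer"):
    print("reader got:", msg)

print(po.state())  # bookkeeping: last message per topic, time spent, etc.
```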
+ + def __init__(self): + self.time_spent = dict() + self._exhausted_topics = set() + self._multi_output_topics = set() + + self._saved_mail = dict() + self._spies = dict() + self._producers = dict() + self._last_msg_produced = dict() + self._last_msg_read = dict() + self._readers_done = dict() + + self._count_time("") + + @property + def topic_names(self): + return list(self._saved_mail.keys()) + + def state(self): + """Get a multi-line string representing the current state, suitable for printing or + logging.""" + result = [] + for topic in self.topic_names: + result.append( + f"Topic {topic}\n" + f" Saved mail: {self._saved_mail[topic]}\n" + f" Last produced: {self._last_msg_produced[topic]}\n" + f" Readers recieved: {self._last_msg_read[topic]}\n" + f" Readers done: {self._readers_done[topic]}\n" + f" Time spent: {self.time_spent.get(topic, None)}\n" + ) + also_spent = {k: v for k, v in self.time_spent.items() if k not in self.topic_names} + result.append(f"Also spent time on: {also_spent}") + result.append(f"Total time spent: {sum(self.time_spent.values())}") + return "\n".join(result) + + def register_producer(self, iterator: ty.Iterator[ty.Any], topic: ty.Union[str, ty.Tuple[str]]): + """Register iterator as the source of messages for topic. + + If topic is a tuple of strings, the iterator should produce (topic -> message) dicts, with + every registered topic in the dict. + + """ + if isinstance(topic, tuple): + if len(topic) == 1: + # Syntax sugar, just a single-output producer + topic = topic[0] + else: + # Multi-output producer, recurse + for sub_topic in topic: + self._multi_output_topics.add(sub_topic) + self.register_producer(iterator, sub_topic) + return + assert isinstance(topic, str) + if topic in self._producers: + raise RuntimeError(f"{topic} already has a producer") + self._register_topic(topic) + self._producers[topic] = iterator + + def _register_topic(self, topic: str): + if topic in self._saved_mail: + return + assert isinstance(topic, str) + self._saved_mail[topic] = [] + self._spies[topic] = [] + self._last_msg_read[topic] = dict() + self._last_msg_produced[topic] = -1 + self._readers_done[topic] = [] + + def register_spy(self, spy: Spy, topic: str): + """Register spy to recieve all messages on topic. + + spy.recieve(msg) will be called for each message, and spy.close() when the topic is + exhausted. + + """ + self._register_topic(topic) + self._spies[topic].append(spy) + + def get_iter(self, topic: str, reader: str): + """Return iterator over messages with topic, for a named reader (usually readers are named + after the messages they produce)""" + self._register_topic(topic) + # Register subscriber + self._last_msg_read[topic][reader] = -1 + # Return generator + return self._read(topic, reader) + + def kill_spies(self, reason=None): + """Close all spies immediately, e.g. during exception handling. + + Reason is passed to spy.kill. 
+ + """ + for spies in self._spies.values(): + for spy in spies: + spy.kill(reason) + + def _count_time(self, topic): + """Start counting time towards topic.""" + now = time.time() + if self.active_topic: + self.time_spent.setdefault(self.active_topic, 0) + self.time_spent[self.active_topic] += now - self._last_switch_time + self.active_topic = topic + self._last_switch_time = now + + def _read(self, topic, reader): + """Actual generator producing messages for reader on topic.""" + msg_number = 0 + while self._message_may_come(topic, msg_number): + # Try to get this message from the cache + for _msg_i, result in self._saved_mail[topic]: + if _msg_i == msg_number: + break + else: + try: + # We have to produce a new message + result = self._fetch_new(topic) + except StopIteration: + # Message actually won't come, exit the while loop. + # (The while condition wasn't triggered because + # the producer only just realized the topic is exhausted) + break + log.debug(f"{reader} receiving message {result}, number {msg_number} of {topic}") + # Note receipt before yielding, so we can clear unnecessary + # messages from our storage (if possible) before we lose control. + self._ack_reader_recieved(reader, topic, msg_number) + # Yield via popping a container to avoid retaining a reference to + # the result https://stackoverflow.com/questions/7133179 + result = [result] + # Return control to the reader, and start counting time on their + # budget + self._count_time(reader) + yield result.pop() + # Reader wants more -- back to working on the topic. + # Look for the next message + self._count_time(topic) + msg_number += 1 + + # We get here if the topic is exhausted & we have read all messages; + # just do a final check for debugging purposes. + self._readers_done[topic].append(reader) + if len(self._readers_done[topic]) == len(self._last_msg_read[topic]): + # All readers are done. Just to be sure, check that we have + # no more messages in the cache. + assert not self._saved_mail[topic] + + def _message_may_come(self, topic, msg_number): + """Return True if topic is guaranteed to never produce msg_number.""" + return not (topic in self._exhausted_topics and msg_number > self._last_msg_produced[topic]) + + def _fetch_new(self, topic): + """Fetch a new message from the producer of topic. + + Raises StopIteration if the topic is exhausted so a new message will never come. + + """ + if topic not in self._producers: + raise RuntimeError(f"No producer registered for {topic}") + try: + msg = next(self._producers[topic]) + except StopIteration: + self._ack_topic_exhausted(topic) + # reraise to end the generator in _read + raise StopIteration + + if topic not in self._multi_output_topics: + log.debug(f"Got simple message {msg} for topic {topic}") + # Simple message, just ack and return to caller + self._ack_msg_produced(msg, topic) + return msg + + # msg is a dict with messages for different topics + assert isinstance(msg, dict) + for sub_msg_topic, sub_msg in msg.items(): + if sub_msg_topic == topic: + # This is what our caller wants + desired_sub_msg = sub_msg + log.debug(f"Got submessage {sub_msg} for sub topic {sub_msg_topic}") + self._ack_msg_produced(sub_msg, sub_msg_topic) + return desired_sub_msg + + def _ack_msg_produced(self, msg, topic): + """Note that msg of topic has been produced.""" + assert topic not in self._exhausted_topics + + self._last_msg_produced[topic] += 1 + + if len(self._last_msg_read.get(topic, [])): + # Someone is interested in this topic, so save the message. 
+ # (If there is only one reader, _ack_reader_recieved will clean + # up this message before we yield control to that reader.) + self._saved_mail[topic].append((self._last_msg_produced[topic], msg)) + + # Deliver the message to the spies (savers/monitors) + for spy in self._spies[topic]: + spy.receive(msg) + + def _ack_reader_recieved(self, reader, topic, msg_number): + """Acknowledge reader got msg_number of topic.""" + # Record receipt + assert self._last_msg_read[topic][reader] == msg_number - 1 + self._last_msg_read[topic][reader] = msg_number + # Keep only messages someone has not yet recieved + everyone_got = min(self._last_msg_read[topic].values()) + log.debug(f"Cleaning out {topic} up to {everyone_got}") + self._saved_mail[topic] = [ + (msg_number, msg) + for msg_number, msg in self._saved_mail[topic] + if msg_number > everyone_got + ] + + def _ack_topic_exhausted(self, topic): + """Take note that topic is exhausted, no new messages will come.""" + for spy in self._spies[topic]: + spy.close() + self._exhausted_topics.add(topic) diff --git a/strax/processors/single_thread.py b/strax/processors/single_thread.py new file mode 100644 index 00000000..77929a47 --- /dev/null +++ b/strax/processors/single_thread.py @@ -0,0 +1,103 @@ +import typing as ty + +from .base import BaseProcessor, ProcessorComponents +from .post_office import PostOffice, Spy + + +import strax + +export, __all__ = strax.exporter() + + +@export +class SingleThreadProcessor(BaseProcessor): + def __init__( + self, components: ProcessorComponents, allow_rechunk=True, is_superrun=False, **kwargs + ): + super().__init__(components, allow_rechunk=allow_rechunk, is_superrun=is_superrun, **kwargs) + + self.log.debug("Processor components are: " + str(components)) + + # Do not use executors: work in one thread in one process + self.process_executor = self.thread_executor = None + + self.post_office = PostOffice() + + for d, loader in components.loaders.items(): + assert d not in components.plugins + self.post_office.register_producer(loader(executor=self.thread_executor), topic=d) + + plugins_seen: ty.List[strax.Plugin] = [] + for d, p in components.plugins.items(): + # Multi-output plugins are listed multiple times in components.plugins; + # ensure we only process each plugin once. + if p in plugins_seen: + continue + plugins_seen.append(p) + + self.post_office.register_producer( + p.iter(iters={dep: self.post_office.get_iter(dep, d) for dep in p.depends_on}), + topic=strax.to_str_tuple(p.provides), + ) + + dtypes_built = {d: p for p in components.plugins.values() for d in p.provides} + for d, savers in components.savers.items(): + for s_i, saver in enumerate(savers): + if d in dtypes_built: + rechunk = dtypes_built[d].can_rechunk(d) and allow_rechunk + else: + rechunk = is_superrun and allow_rechunk + + self.post_office.register_spy(SaverSpy(saver, rechunk=rechunk), topic=d) + + def iter(self): + target = self.components.targets[0] + final_generator = self.post_office.get_iter(topic=target, reader="FINAL") + + self.log.debug(f"Yielding {target}") + + try: + yield from final_generator + + except Exception: + # Exception in one of the producers. Close savers (they will record + # the exception from sys.exc_info()) then reraise. + self.log.fatal(f"Exception during processing, closing savers and reraising") + self.post_office.kill_spies() + raise + + except GeneratorExit: + self.log.fatal( + "Exception in code that called the processor: detected " + "GeneratorExit from python shutting down. " + "Closing savers and exiting." 
+ ) + # Strax savers look at sys.exc_info(). Having only "GeneratorExit" + # there is unhelpful.. this should set it to something better: + try: + raise RuntimeError("Exception in caller, see log for details") + except RuntimeError: + self.post_office.kill_spies() + + self.log.debug("Processing finished") + + +class SaverSpy(Spy): + """Spy that saves messages to a saver.""" + + def __init__(self, saver, rechunk=False): + self.saver = saver + self.rechunker = strax.Rechunker(rechunk, self.saver.md["run_id"]) + self.chunk_number = 0 + + def receive(self, chunk): + self._save_chunk(self.rechunker.receive(chunk)) + + def _save_chunk(self, chunk): + if chunk is not None: + self.saver.save(chunk, self.chunk_number) + self.chunk_number += 1 + + def close(self): + self._save_chunk(self.rechunker.flush()) + self.saver.close() diff --git a/strax/processors/threaded_mailbox.py b/strax/processors/threaded_mailbox.py new file mode 100644 index 00000000..10f43052 --- /dev/null +++ b/strax/processors/threaded_mailbox.py @@ -0,0 +1,318 @@ +from concurrent import futures +from functools import partial +import logging +import typing as ty +import os +import sys +from concurrent.futures import ProcessPoolExecutor + +import numpy as np +from .base import ProcessorComponents, BaseProcessor + +import strax + +export, __all__ = strax.exporter() + + +class MailboxDict(dict): + def __init__(self, *args, lazy=False, **kwargs): + super().__init__(*args, **kwargs) + self.lazy = lazy + + def __missing__(self, key): + res = self[key] = strax.Mailbox(name=key + "_mailbox", lazy=self.lazy) + return res + + +try: + import npshmex + + SHMExecutor = npshmex.ProcessPoolExecutor + npshmex.register_array_wrapper(strax.Chunk, "data") +except ImportError: + # This is allowed to fail, it only crashes if allow_shm = True + SHMExecutor = None +__all__.append("SHMExecutor") + + +@export +class ThreadedMailboxProcessor(BaseProcessor): + mailboxes: ty.Dict[str, strax.Mailbox] + + def __init__( + self, + components: ProcessorComponents, + allow_rechunk=True, + allow_shm=False, + allow_multiprocess=False, + allow_lazy=True, + max_workers=None, + max_messages=4, + timeout=60, + is_superrun=False, + ): + self.log = logging.getLogger(self.__class__.__name__) + self.components = components + + self.log.debug("Processor components are: " + str(components)) + + if allow_multiprocess and os.name == "nt": + print("You're on Windows! Multiprocessing disabled, here be dragons.") + allow_multiprocess = False + + if max_workers in [None, 1]: + # Disable the executors: work in one process. + # Each plugin works completely in its own thread. + self.process_executor = self.thread_executor = None + lazy = allow_lazy + else: + lazy = False + # Use executors for parallelization of computations. + self.thread_executor = futures.ThreadPoolExecutor(max_workers=max_workers) + + mp_plugins = {d: p for d, p in components.plugins.items() if p.parallel == "process"} + if allow_multiprocess and len(mp_plugins): + _proc_ex = ProcessPoolExecutor + if allow_shm: + if SHMExecutor is None: + raise RuntimeError( + "You must install npshmex to enable shm transfer of numpy arrays." 
+ ) + _proc_ex = SHMExecutor + self.process_executor = _proc_ex(max_workers=max_workers) + + # Combine as many plugins /savers as possible in one process + # TODO: more intelligent start determination, multiple starts + start_from = list(mp_plugins.keys())[ + int(np.argmin([len(p.depends_on) for p in mp_plugins.values()])) + ] + components = strax.ParallelSourcePlugin.inline_plugins( + components, start_from, log=self.log + ) + self.components = components + self.log.debug("Altered components for multiprocessing: " + str(components)) + else: + self.process_executor = self.thread_executor # type: ignore + + # Figure which outputs + # - we should exclude from the flow control in lazy mode, + # because they are produced but not required. + # - we should discard (produced but neither required not saved) + produced = set(components.loaders) + required = set(components.targets) + # Do not just take keys from savers, perhaps some keys + # have no savers are under them (see #444) + saved = set([k for k, v in components.savers.items() if v]) + + for p in components.plugins.values(): + produced.update(p.provides) + required.update(p.depends_on) + to_flow_freely = produced - required + to_discard = to_flow_freely - saved + self.log.debug( + f"to_flow_freely {to_flow_freely}" + f"to_discard {to_discard}" + f"produced {produced}" + f"required {required}" + f"saved {saved}" + ) + + self.mailboxes = MailboxDict(lazy=lazy) + + for d, loader in components.loaders.items(): + assert d not in components.plugins + # If paralellizing, use threads for loading + # the decompressor releases the gil, and we have a lot + # of data transfer to do + self.mailboxes[d].add_sender(loader(executor=self.thread_executor), name=f"load:{d}") + + multi_output_seen: ty.List[strax.Plugin] = [] + for d, p in components.plugins.items(): + if p in multi_output_seen: + continue + + if p.__class__ in [mp_seen.__class__ for mp_seen in multi_output_seen]: + raise ValueError( + "A multi-output plugin is registered with different " + "instances for its provided data_types!" + ) + + executor = None + if p.parallel == "process": + executor = self.process_executor + elif p.parallel: + executor = self.thread_executor # type: ignore + + if p.multi_output: + multi_output_seen.append(p) + + # Create temp mailbox that receives multi-output dicts + # and sends them forth to other mailboxes + mname = p.__class__.__name__ + "_divide_outputs" + self.mailboxes[mname].add_sender( + p.iter( + iters={dep: self.mailboxes[dep].subscribe() for dep in p.depends_on}, + executor=executor, + ), + name=f"divide_outputs:{d}", + ) + + # If we have a plugin with double dependency both outputs + # of a multioutput-plugin are required. 
Hence flow-freely + # is empty an needs to be updated here: + provided_data_types = set(p.provides) + reader_data_types = set(strax.to_str_tuple(d)) + double_dependency = provided_data_types - reader_data_types + to_flow_freely |= double_dependency + self.log.debug(f"Updating flow freely for {mname} to be {to_flow_freely}") + + self.mailboxes[mname].add_reader( + partial( + strax.divide_outputs, + lazy=lazy, + # make sure to subscribe the outputs of the mp_plugins + mailboxes={k: self.mailboxes[k] for k in p.provides}, + flow_freely=to_flow_freely, + outputs=p.provides, + ) + ) + + else: + self.mailboxes[d].add_sender( + p.iter( + iters={dep: self.mailboxes[dep].subscribe() for dep in p.depends_on}, + executor=executor, + ), + name=f"build:{d}", + ) + + dtypes_built = {d: p for p in components.plugins.values() for d in p.provides} + for d, savers in components.savers.items(): + for s_i, saver in enumerate(savers): + if d in dtypes_built: + can_drive = not lazy + rechunk = dtypes_built[d].can_rechunk(d) and allow_rechunk + else: + # This is storage conversion mode + # TODO: Don't know how to get this info, for now, + # be conservative and don't rechunk + can_drive = True + rechunk = is_superrun and allow_rechunk + + self.mailboxes[d].add_reader( + partial( + saver.save_from, + rechunk=rechunk, + # If paralellizing, use threads for saving + # the compressor releases the gil, + # and we have a lot of data transfer to do + executor=self.thread_executor, + ), + can_drive=can_drive, + name=f"save_{s_i}:{d}", + ) + + # For multi-output plugins, an output may be neither saved nor + # required, and thus has to be discarded. + # This should happen rarely in production (when you actually + # care about the data, you will be saving it) + def discarder(source): + for _ in source: + pass + + for d in to_discard: + self.mailboxes[d].add_reader(discarder, name=f"discard_{d}") + + # Set to preferred number of maximum messages + # TODO: may not work if plugins are inlined?? + for d, m in self.mailboxes.items(): + m.max_messages = max_messages + m.timeout = timeout + if d in components.plugins: + max_m = components.plugins[d].max_messages + if max_m is not None: + m.max_messages = max_m + + # Remove defaultdict-like behaviour; all mailboxes should + # have been made by now. See #444 + self.mailboxes = dict(self.mailboxes) + self.log.debug( + f"Created the following mailboxes: {self.mailboxes} with the " + f"following threads: {[(d, m._threads) for d, m in self.mailboxes.items()]}" + ) + + def iter(self): + target = self.components.targets[0] + final_generator = self.mailboxes[target].subscribe() + + self.log.debug("Starting threads") + for m in self.mailboxes.values(): + self.log.debug(f"start {m}") + m.start() + + self.log.debug(f"Yielding {target}") + traceback, exc, reason = None, None, None + + try: + yield from final_generator + + # GeneratorExit results from exception in caller + # (on garbage collection, .close() is called, see PEP342) + except (Exception, GeneratorExit) as e: + self.log.fatal(f"Target Mailbox ({target}) killed, exception {type(e)}, message {e}") + if isinstance(e, strax.MailboxKilled): + _, exc, traceback = reason = e.args[0] + else: + exc = e + reason = (e.__class__, e, sys.exc_info()[2]) + traceback = reason[2] + + # We will reraise it in just a moment... + + if exc is not None: + if isinstance(exc, GeneratorExit): + print("Main generator exited irregularly?!") + reason[2] = ( + "Hm, interesting. 
Most likely an exception was thrown " + "outside strax, but we did not handle it properly." + ) + + # Kill the mailboxes + for m in self.mailboxes.values(): + if m != target: + self.log.debug(f"Killing {m}") + m.kill(upstream=True, reason=reason) + + self.log.debug("Closing threads") + for m in self.mailboxes.values(): + m.cleanup() + self.log.debug("Closing threads completed") + + self.log.debug("Closing executors") + if self.thread_executor is not None: + self.thread_executor.shutdown(wait=True) + if self.process_executor not in [None, self.thread_executor]: + self.process_executor.shutdown(wait=True) + self.log.debug("Closing executors completed") + + if exc is not None: + # Reraise exception. This is outside the except block + # to avoid the 'during handling of this exception, another + # exception occurred' stuff from confusing the traceback + # which is printed for the user + self.log.debug("Reraising exception") + raise exc.with_traceback(traceback) + + # Check the savers for any exception that occurred during saving + # These are thrown back to the mailbox, but if that has already closed + # it doesn't trigger a crash... + # TODO: add savers inlined by parallelsourceplugin + # TODO: need to look at plugins too if we ever implement true + # multi-target mode + for k, saver_list in self.components.savers.items(): + for s in saver_list: + if s.got_exception: + self.log.fatal(f"Caught error while saving {k}!") + raise s.got_exception + + self.log.debug("Processing finished") diff --git a/strax/storage/common.py b/strax/storage/common.py index 36328969..f7dbda43 100644 --- a/strax/storage/common.py +++ b/strax/storage/common.py @@ -635,36 +635,22 @@ def save_from(self, source: typing.Generator, rechunk=True, executor=None): exhausted = False chunk_i = 0 - run_id = self.md["run_id"] - _is_superrun = run_id.startswith("_") and not run_id.startswith("__") + rechunker = strax.Rechunker( + rechunk=rechunk and self.allow_rechunk, run_id=self.md["run_id"] + ) + try: while not exhausted: chunk = None try: - if rechunk and self.allow_rechunk: - while chunk is None or chunk.data.nbytes < chunk.target_size_mb * 1e6: - next_chunk = next(source) - - if _is_superrun: - # If we are creating a superrun, we load data from subruns - # and the loaded subrun chunk becomes a superun chunk: - next_chunk = strax.transform_chunk_to_superrun_chunk( - run_id, next_chunk - ) - chunk = strax.Chunk.concatenate([chunk, next_chunk]) - else: - chunk = next(source) - if _is_superrun: - # If we are creating a superrun, we load data from subruns - # and the loaded subrun chunk becomes a superun chunk: - chunk = strax.transform_chunk_to_superrun_chunk(run_id, chunk) - + chunk = rechunker.receive(next(source)) except StopIteration: exhausted = True + chunk = rechunker.flush() if chunk is None: - break + continue new_f = self.save(chunk=chunk, chunk_i=chunk_i, executor=executor) pending = [f for f in pending if not f.done()] diff --git a/strax/storage/file_rechunker.py b/strax/storage/file_rechunker.py index 936d44d1..250f389e 100644 --- a/strax/storage/file_rechunker.py +++ b/strax/storage/file_rechunker.py @@ -189,13 +189,12 @@ def _get_meta_data_and_compressor(backend, source_directory, compressor, target_ def _get_executor(parallel, max_workers): - # nested import - prevent circular imports - from strax.processor import SHMExecutor - return { True: ThreadPoolExecutor(max_workers), "thread": ThreadPoolExecutor(max_workers), "process": ( - ProcessPoolExecutor(max_workers) if SHMExecutor is None else 
SHMExecutor(max_workers) + ProcessPoolExecutor(max_workers) + if strax.SHMExecutor is None + else strax.SHMExecutor(max_workers) ), }.get(parallel) diff --git a/tests/test_core.py b/tests/test_core.py index 7bbc0a18..add6d36f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -7,20 +7,29 @@ from strax.testutils import * - -def test_core(): - for allow_multiprocess in (False, True): - for max_workers in [1, 2]: - mystrax = strax.Context( - storage=[], - register=[Records, Peaks], - allow_multiprocess=allow_multiprocess, - use_per_run_defaults=True, - ) - bla = mystrax.get_array(run_id=run_id, targets="peaks", max_workers=max_workers) - p = mystrax.get_single_plugin(run_id, "records") - assert len(bla) == p.config["recs_per_chunk"] * p.config["n_chunks"] - assert bla.dtype == strax.peak_dtype() +processing_conditions = pytest.mark.parametrize( + "allow_multiprocess,max_workers,processor", + [ + (False, 1, "threaded_mailbox"), + (True, 2, "threaded_mailbox"), + (False, 1, "single_thread"), + ], +) + + +@processing_conditions +def test_core(allow_multiprocess, max_workers, processor): + mystrax = strax.Context( + storage=[], + register=[Records, Peaks], + processors=[processor], + allow_multiprocess=allow_multiprocess, + use_per_run_defaults=True, + ) + bla = mystrax.get_array(run_id=run_id, targets="peaks", max_workers=max_workers) + p = mystrax.get_single_plugin(run_id, "records") + assert len(bla) == p.config["recs_per_chunk"] * p.config["n_chunks"] + assert bla.dtype == strax.peak_dtype() def test_multirun(): @@ -37,11 +46,14 @@ def test_multirun(): np.testing.assert_equal(bla["run_id"], np.array(["0"] * n + ["1"] * n)) -def test_filestore(): +@processing_conditions +def test_filestore(allow_multiprocess, max_workers, processor): with tempfile.TemporaryDirectory() as temp_dir: mystrax = strax.Context( storage=strax.DataDirectory(temp_dir, deep_scan=True), register=[Records, Peaks], + processors=[processor], + allow_multiprocess=allow_multiprocess, use_per_run_defaults=True, ) @@ -197,30 +209,31 @@ def test_storage_converter(): store_2.find(key) -def test_exception(): - for allow_multiprocess, max_workers in zip((False, True), (1, 2)): - with tempfile.TemporaryDirectory() as temp_dir: - st = strax.Context( - storage=strax.DataDirectory(temp_dir), - register=[Records, Peaks], - allow_multiprocess=allow_multiprocess, - config=dict(crash=True), - use_per_run_defaults=True, - ) +@processing_conditions +def test_exception(allow_multiprocess, max_workers, processor): + with tempfile.TemporaryDirectory() as temp_dir: + st = strax.Context( + storage=strax.DataDirectory(temp_dir), + register=[Records, Peaks], + processors=[processor], + allow_multiprocess=allow_multiprocess, + config=dict(crash=True), + use_per_run_defaults=True, + ) - # Check correct exception is thrown - with pytest.raises(SomeCrash): - st.make(run_id=run_id, targets="peaks", max_workers=max_workers) + # Check correct exception is thrown + with pytest.raises(SomeCrash): + st.make(run_id=run_id, targets="peaks", max_workers=max_workers) - # Check exception is recorded in metadata - # in both its original data type and dependents - for target in ("peaks", "records"): - assert "SomeCrash" in st.get_meta(run_id, target)["exception"] + # Check exception is recorded in metadata + # in both its original data type and dependents + for target in ("peaks", "records"): + assert "SomeCrash" in st.get_meta(run_id, target)["exception"] - # Check corrupted data does not load - st.context_config["forbid_creation_of"] = ("peaks",) - with 
pytest.raises(strax.DataNotAvailable): - st.get_df(run_id=run_id, targets="peaks", max_workers=max_workers) + # Check corrupted data does not load + st.context_config["forbid_creation_of"] = ("peaks",) + with pytest.raises(strax.DataNotAvailable): + st.get_df(run_id=run_id, targets="peaks", max_workers=max_workers) def test_exception_in_saver(caplog):
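Editor's note: the rewritten tests above run the same assertions under both the `threaded_mailbox` and `single_thread` backends. A natural companion check, sketched here (not part of this diff) with the `strax.testutils` plugins, is that both backends produce identical output for the same run:

```python
import numpy as np
import strax
from strax.testutils import Records, Peaks, run_id

def test_processors_agree():
    """Hypothetical extra test: both backends yield the same peaks."""
    results = {}
    for processor in ("threaded_mailbox", "single_thread"):
        st = strax.Context(
            storage=[],
            register=[Records, Peaks],
            processors=[processor],
            use_per_run_defaults=True,
        )
        data = st.get_array(run_id=run_id, targets="peaks")
        assert data.dtype == strax.peak_dtype()
        results[processor] = data
    np.testing.assert_equal(results["threaded_mailbox"], results["single_thread"])
```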