diff --git a/strax/context.py b/strax/context.py
index c3eddb706..747447a5e 100644
--- a/strax/context.py
+++ b/strax/context.py
@@ -683,7 +683,11 @@ def _fix_dependency(self, plugin_registry: dict, end_plugin: str):
             self._fix_dependency(plugin_registry, go_to)
         plugin_registry[end_plugin].fix_dtype()
 
-    def __get_plugins_from_cache(self, run_id: str) -> ty.Dict[str, strax.Plugin]:
+    def __get_requested_plugins_from_cache(
+        self,
+        run_id: str,
+        targets: ty.Tuple[str],
+    ) -> ty.Dict[str, strax.Plugin]:
         # Doubly underscored since we don't do any key-checks etc. here
         """Load requested plugins from the plugin_cache."""
         requested_plugins = {}
@@ -709,15 +713,20 @@ def __get_plugins_from_cache(self, run_id: str) -> ty.Dict[str, strax.Plugin]:
             plugin.deps = {
                 dependency: requested_plugins[dependency] for dependency in plugin.depends_on
             }
+
         # Finally, fix the dtype. Since infer_dtype may depend on the
         # entire deps chain, we need to start at the last plugin and go
         # all the way down to the lowest level.
-        for final_plugins in self._get_end_targets(requested_plugins):
-            self._fix_dependency(requested_plugins, final_plugins)
+        for target_plugins in targets:
+            self._fix_dependency(requested_plugins, target_plugins)
+
+        requested_plugins = {i: v for i, v in requested_plugins.items() if i in targets}
 
         return requested_plugins
 
     def _get_plugins(
-        self, targets: ty.Union[ty.Tuple[str], ty.List[str]], run_id: str
+        self,
+        targets: ty.Union[ty.Tuple[str], ty.List[str]],
+        run_id: str,
     ) -> ty.Dict[str, strax.Plugin]:
         """Return dictionary of plugin instances necessary to compute targets from scratch.
@@ -725,21 +734,6 @@ def _get_plugins(
         referenced under multiple keys in the output dict.
 
         """
-        if self._plugins_are_cached(targets):
-            cached_plugins = self.__get_plugins_from_cache(run_id)
-            plugins = {}
-            targets = list(targets)
-            while targets:
-                target = targets.pop(0)
-                if target in plugins:
-                    continue
-
-                target_plugin = cached_plugins[target]
-                for provides in target_plugin.provides:
-                    plugins[provides] = target_plugin
-                targets += list(target_plugin.depends_on)
-            return plugins
-
         # Check all config options are taken by some registered plugin class
         # (helps spot typos)
         all_opts = set().union(*[
@@ -749,98 +743,126 @@ def _get_plugins(
             if not (k in all_opts or k in self.context_config["free_options"]):
                 self.log.warning(f"Option {k} not taken by any registered plugin")
 
-        # Initialize plugins for the entire computation graph
-        # (most likely far further down than we need)
-        # to get lineages and dependency info.
-        def get_plugin(data_type):
-            nonlocal non_local_plugins
+        plugins = {}
+        targets = list(targets)
+        safety_counter = 0
+        while targets and safety_counter < 10_000:
+            safety_counter += 1
+            targets = list(set(targets))  # Remove duplicates from list.
+            target = targets.pop(0)
+            if target in plugins:
+                continue
 
-            if data_type not in self._plugin_class_registry:
-                raise KeyError(f"No plugin class registered that provides {data_type}")
+            target_plugin = self.__get_plugin(run_id, target)
+            for provides in target_plugin.provides:
+                plugins[provides] = target_plugin
+            targets += list(target_plugin.depends_on)
 
-            plugin = self._plugin_class_registry[data_type]()
+        _not_all_plugins_initialized = (safety_counter == 10_000) & len(targets)
+        if _not_all_plugins_initialized:
+            raise ValueError(
+                "Could not initialize all plugins to compute target from scratch. "
" + f"The reamining targets missing are: {targets}" + ) - d_provides = None # just to make codefactor happy - for d_provides in plugin.provides: - non_local_plugins[d_provides] = plugin + return plugins - plugin.run_id = run_id + def __get_plugin(self, run_id: str, data_type: str): + """Get single plugin either from cache or initialize it.""" + # Check if plugin for data_type is already cached + if self._plugins_are_cached((data_type,)): + cached_plugins = self.__get_requested_plugins_from_cache(run_id, (data_type,)) + target_plugin = cached_plugins[data_type] + return target_plugin - # The plugin may not get all the required options here - # but we don't know if we need the plugin yet - self._set_plugin_config(plugin, run_id, tolerant=True) + if data_type not in self._plugin_class_registry: + raise KeyError(f"No plugin class registered that provides {data_type}") - plugin.deps = {d_depends: get_plugin(d_depends) for d_depends in plugin.depends_on} + plugin = self._plugin_class_registry[data_type]() - last_provide = d_provides + plugin.run_id = run_id - if plugin.child_plugin: - # Plugin is a child of another plugin, hence we have to - # drop the parents config from the lineage - configs = {} + # The plugin may not get all the required options here + # but we don't know if we need the plugin yet + self._set_plugin_config(plugin, run_id, tolerant=True) - # Getting information about the parent: - parent_class = plugin.__class__.__bases__[0] - # Get all parent options which are overwritten by a child: - parent_options = [ - option.parent_option_name - for option in plugin.takes_config.values() - if option.child_option - ] + plugin.deps = { + d_depends: self.__get_plugin(run_id, d_depends) for d_depends in plugin.depends_on + } - for option_name, v in plugin.config.items(): - # Looping over all settings, option_name is either the option name of the - # parent or the child. 
-                    if option_name in parent_options:
-                        # In case it is the parent we continue
-                        continue
+        self.__add_lineage_to_plugin(run_id, plugin)
 
-                    if plugin.takes_config[option_name].track:
-                        # Add all options which should be tracked:
-                        configs[option_name] = v
+        if not hasattr(plugin, "data_kind") and not plugin.multi_output:
+            if len(plugin.depends_on):
+                # Assume data kind is the same as the first dependency
+                first_dep = plugin.depends_on[0]
+                plugin.data_kind = plugin.deps[first_dep].data_kind_for(first_dep)
+            else:
+                # No dependencies: assume provided data kind and
+                # data type are synonymous
+                plugin.data_kind = plugin.provides[0]
 
-                # Also adding name and version of the parent to the lineage:
-                configs[parent_class.__name__] = parent_class.__version__
+        plugin.fix_dtype()
 
-                plugin.lineage = {
-                    last_provide: (plugin.__class__.__name__, plugin.version(run_id), configs)
-                }
-            else:
-                plugin.lineage = {
-                    last_provide: (
-                        plugin.__class__.__name__,
-                        plugin.version(run_id),
-                        {
-                            option: setting
-                            for option, setting in plugin.config.items()
-                            if plugin.takes_config[option].track
-                        },
-                    )
-                }
-            for d_depends in plugin.depends_on:
-                plugin.lineage.update(plugin.deps[d_depends].lineage)
-
-            if not hasattr(plugin, "data_kind") and not plugin.multi_output:
-                if len(plugin.depends_on):
-                    # Assume data kind is the same as the first dependency
-                    first_dep = plugin.depends_on[0]
-                    plugin.data_kind = plugin.deps[first_dep].data_kind_for(first_dep)
-                else:
-                    # No dependencies: assume provided data kind and
-                    # data type are synonymous
-                    plugin.data_kind = plugin.provides[0]
+        # Add plugin to cache
+        self._plugins_to_cache({data_type: plugin for data_type in plugin.provides})
 
-            plugin.fix_dtype()
+        return plugin
 
-            return plugin
+    def __add_lineage_to_plugin(self, run_id, plugin):
+        """Add the lineage to the plugin in place.
 
-        non_local_plugins = {}
-        for t in targets:
-            p = get_plugin(t)
-            non_local_plugins[t] = p
+        Also adds the parent information in case of a child plugin.
+
+        """
+        last_provide = [d_provides for d_provides in plugin.provides][-1]
+
+        if plugin.child_plugin:
+            # Plugin is a child of another plugin, hence we have to
+            # drop the parents config from the lineage
+            configs = {}
+
+            # Getting information about the parent:
+            parent_class = plugin.__class__.__bases__[0]
+            # Get all parent options which are overwritten by a child:
+            parent_options = [
+                option.parent_option_name
+                for option in plugin.takes_config.values()
+                if option.child_option
+            ]
+
+            for option_name, v in plugin.config.items():
+                # Looping over all settings, option_name is either the option name of the
+                # parent or the child.
+                if option_name in parent_options:
+                    # In case it is the parent we continue
+                    continue
+
+                if plugin.takes_config[option_name].track:
+                    # Add all options which should be tracked:
+                    configs[option_name] = v
+
+            # Also adding name and version of the parent to the lineage:
+            configs[parent_class.__name__] = parent_class.__version__
+
+            plugin.lineage = {
+                last_provide: (plugin.__class__.__name__, plugin.version(run_id), configs)
+            }
+        else:
+            plugin.lineage = {
+                last_provide: (
+                    plugin.__class__.__name__,
+                    plugin.version(run_id),
+                    {
+                        option: setting
+                        for option, setting in plugin.config.items()
+                        if plugin.takes_config[option].track
+                    },
+                )
+            }
 
-        self._plugins_to_cache(non_local_plugins)
-        return non_local_plugins
+        for d_depends in plugin.depends_on:
+            plugin.lineage.update(plugin.deps[d_depends].lineage)
 
     def _per_run_default_allowed_check(self, option_name, option):
         """Check if an option of a registered plugin is allowed."""
diff --git a/strax/plugins/__init__.py b/strax/plugins/__init__.py
index ada9e1ef0..ffbb5b014 100644
--- a/strax/plugins/__init__.py
+++ b/strax/plugins/__init__.py
@@ -4,3 +4,4 @@
 from .merge_only_plugin import *
 from .overlap_window_plugin import *
 from .parrallel_source_plugin import *
+from .down_chunking_plugin import *
diff --git a/strax/plugins/down_chunking_plugin.py b/strax/plugins/down_chunking_plugin.py
new file mode 100644
index 000000000..fb34e2b29
--- /dev/null
+++ b/strax/plugins/down_chunking_plugin.py
@@ -0,0 +1,44 @@
+import strax
+from .plugin import Plugin
+
+export, __all__ = strax.exporter()
+
+
+##
+# Plugin which allows the use of yield in a plugin's compute method.
+# This allows chunking the output down before storing it to disk.
+# Only works if multiprocessing is omitted.
+##
+
+
+@export
+class DownChunkingPlugin(Plugin):
+    """Plugin that allows its compute method to yield multiple output chunks."""
+
+    parallel = False
+
+    def __init__(self):
+        super().__init__()
+
+        if self.parallel:
+            raise NotImplementedError(
+                f'Plugin "{self.__class__.__name__}" is a DownChunkingPlugin which '
+                "currently does not support parallel processing."
+            )
+
+        if self.multi_output:
+            raise NotImplementedError(
+                f'Plugin "{self.__class__.__name__}" is a DownChunkingPlugin which '
+                "currently does not support multiple outputs. Please only provide "
+                "a single data-type."
+            )
+
+    def iter(self, iters, executor=None):
+        return super().iter(iters, executor)
+
+    def _iter_compute(self, chunk_i, **inputs_merged):
+        return self.do_compute(chunk_i=chunk_i, **inputs_merged)
+
+    def _fix_output(self, result, start, end, _dtype=None):
+        """Override of Plugin._fix_output to support compute methods that yield chunks."""
+        return result
diff --git a/strax/plugins/plugin.py b/strax/plugins/plugin.py
index ee95d22c6..afaa9d161 100644
--- a/strax/plugins/plugin.py
+++ b/strax/plugins/plugin.py
@@ -133,6 +135,8 @@ def __init__(self):
             # not have to updated save_when
             self.save_when = immutabledict.fromkeys(self.provides, self.save_when)
 
+        if getattr(self, "provides", None):
+            self.provides = strax.to_str_tuple(self.provides)
         self.compute_pars = compute_pars
         self.input_buffer = dict()
 
@@ -492,7 +494,7 @@ class IterDone(Exception):
                     pending_futures = [f for f in pending_futures if not f.done()]
                     yield new_future
                 else:
-                    yield self.do_compute(chunk_i=chunk_i, **inputs_merged)
+                    yield from self._iter_compute(chunk_i=chunk_i, **inputs_merged)
 
         except IterDone:
             # Check all sources are exhausted.
@@ -517,6 +519,10 @@ class IterDone(Exception):
         finally:
             self.cleanup(wait_for=pending_futures)
 
+    def _iter_compute(self, chunk_i, **inputs_merged):
+        """Yield the strax chunk(s) produced by do_compute."""
+        yield self.do_compute(chunk_i=chunk_i, **inputs_merged)
+
     def cleanup(self, wait_for):
         pass  # A standard plugin doesn't need to do anything here
 
diff --git a/strax/processing/general.py b/strax/processing/general.py
index 2c796ea2f..ddf512cdc 100644
--- a/strax/processing/general.py
+++ b/strax/processing/general.py
@@ -1,4 +1,3 @@
-import os
 import warnings
 
 warnings.simplefilter("always", UserWarning)
@@ -58,29 +57,25 @@ def _sort_by_time_and_channel(x, channel, max_channel_plus_one, sort_kind="merge
     return x[sort_i]
 
 
-# Getting endtime jitted is a bit awkward, especially since it has to
-# keep working with NUMBA_DISABLE_JIT, which we use for coverage tests.
-# See https://github.com/numba/numba/issues/4759
-if os.environ.get("NUMBA_DISABLE_JIT"):
-
-    @export
-    def endtime(x):
-        """Return endtime of intervals x."""
-        if "endtime" in x.dtype.fields:
-            return x["endtime"]
-        else:
-            return x["time"] + x["length"] * x["dt"]
-
-else:
-
-    @export
-    @numba.generated_jit(nopython=True, nogil=True)
-    def endtime(x):
-        """Return endtime of intervals x."""
-        if "endtime" in x.dtype.fields:
-            return lambda x: x["endtime"]
-        else:
-            return lambda x: x["time"] + x["length"] * x["dt"]
+@export
+def endtime(x):
+    """Return endtime of intervals x."""
+    if "endtime" in x.dtype.fields:
+        return x["endtime"]
+    else:
+        return x["time"] + x["length"] * x["dt"]
+
+
+# Jitting endtime needs special attention, since inspecting the dtype
+# has to happen in the Python layer.
+# (Used to work through numba.generated_jit, now numba.extending.overload)
+@numba.extending.overload(endtime)
+def _overload_endtime(x):
+    """Return endtime of intervals x."""
+    if "endtime" in x.dtype.fields:
+        return lambda x: x["endtime"]
+    else:
+        return lambda x: x["time"] + x["length"] * x["dt"]
 
 
 @export
diff --git a/strax/testutils.py b/strax/testutils.py
index 1d2f25e72..eea791e63 100644
--- a/strax/testutils.py
+++ b/strax/testutils.py
@@ -252,6 +252,58 @@ def compute(self, peaks):
         return dict(peak_classification=p, lone_hits=lh)
 
 
+# Plugins with time structure within chunks,
+# used to test down-chunking within plugin compute.
+class RecordsWithTimeStructure(Records):
+    """Same as Records but with some structure in "time" for testing."""
+
+    def setup(self):
+        self.last_end = 0
+
+    def compute(self, chunk_i):
+        r = np.zeros(self.config["recs_per_chunk"], self.dtype)
+        r["time"] = self.last_end + np.arange(self.config["recs_per_chunk"]) + 5
+        r["length"] = r["dt"] = 1
+        r["channel"] = np.arange(len(r))
+
+        end = self.last_end + self.config["recs_per_chunk"] + 10
+        chunk = self.chunk(start=self.last_end, end=end, data=r)
+        self.last_end = end
+
+        return chunk
+
+
+class DownSampleRecords(strax.DownChunkingPlugin):
+    """Plugin to test down-chunking of chunks during compute.
+
+    Needed for simulations.
+
+    """
+
+    provides = "records_down_chunked"
+    depends_on = "records"
+    dtype = strax.record_dtype()
+    rechunk_on_save = False
+
+    def compute(self, records, start, end):
+        offset = 0
+        last_start = start
+
+        count = 0
+        for count, r in enumerate(records):
+            if count == 5:
+                res = records[offset:count]
+                chunk_end = np.max(strax.endtime(res))
+                offset = count
+                chunk = self.chunk(start=last_start, end=chunk_end, data=res)
+                last_start = chunk_end
+                yield chunk
+
+        res = records[offset : count + 1]
+        chunk = self.chunk(start=last_start, end=end, data=res)
+        yield chunk
+
+
 # Used in test_core.py
 run_id = "0"
diff --git a/tests/test_context.py b/tests/test_context.py
index 7602277fa..067aeed17 100644
--- a/tests/test_context.py
+++ b/tests/test_context.py
@@ -1,5 +1,11 @@
 import strax
-from strax.testutils import Records, Peaks, PeaksWoPerRunDefault, PeakClassification, run_id
+from strax.testutils import (
+    Records,
+    Peaks,
+    PeaksWoPerRunDefault,
+    PeakClassification,
+    run_id,
+)
 import tempfile
 import numpy as np
 from hypothesis import given, settings
@@ -301,10 +307,10 @@ def test_deregister(self):
         st.deregister_plugins_with_missing_dependencies()
         assert st._plugin_class_registry.pop("peaks", None) is None
 
-    def get_context(self, use_defaults):
+    def get_context(self, use_defaults, **kwargs):
         """Get simple context where we have one mock run in the only storage frontend."""
         assert isinstance(use_defaults, bool)
-        st = strax.Context(storage=self.get_mock_sf(), check_available=("records",))
+        st = strax.Context(storage=self.get_mock_sf(), check_available=("records",), **kwargs)
         st.set_context_config({"use_per_run_defaults": use_defaults})
         return st
 
diff --git a/tests/test_down_chunk_plugin.py b/tests/test_down_chunk_plugin.py
new file mode 100644
index 000000000..f0d3d8a90
--- /dev/null
+++ b/tests/test_down_chunk_plugin.py
@@ -0,0 +1,69 @@
+from strax.testutils import RecordsWithTimeStructure, DownSampleRecords, run_id
+import strax
+import numpy as np
+
+import os
+import tempfile
+import shutil
+import uuid
+import unittest
+
+
+class TestContext(unittest.TestCase):
+    """Tests for the DownChunkingPlugin class."""
+
+    def setUp(self):
+        """Make temp folder to write data to."""
+        temp_folder = uuid.uuid4().hex
+        self.tempdir = os.path.join(tempfile.gettempdir(), temp_folder)
+        assert not os.path.exists(self.tempdir)
+
+    def tearDown(self):
+        if os.path.exists(self.tempdir):
+            shutil.rmtree(self.tempdir)
+
+    def test_down_chunking(self):
+        st = self.get_context()
+        st.register(RecordsWithTimeStructure)
+        st.register(DownSampleRecords)
+
+        st.make(run_id, "records")
+        st.make(run_id, "records_down_chunked")
+
+        chunks_records = st.get_meta(run_id, "records")["chunks"]
+        chunks_records_down_chunked = st.get_meta(run_id, "records_down_chunked")["chunks"]
+
+        _chunks_are_downsampled = len(chunks_records) * 2 == len(chunks_records_down_chunked)
+        assert _chunks_are_downsampled
+
+        _chunks_are_contiguous = np.all([
+            chunks_records_down_chunked[i]["end"] == chunks_records_down_chunked[i + 1]["start"]
+            for i in range(len(chunks_records_down_chunked) - 1)
+        ])
+        assert _chunks_are_contiguous
+
+    def test_down_chunking_multi_processing(self):
+        st = self.get_context(allow_multiprocess=True)
+        st.register(RecordsWithTimeStructure)
+        st.register(DownSampleRecords)
+
+        st.make(run_id, "records", max_workers=1)
+
+        class TestMultiProcessing(DownSampleRecords):
+            parallel = True
+
+        st.register(TestMultiProcessing)
+        with self.assertRaises(NotImplementedError):
+            st.make(run_id, "records_down_chunked", max_workers=2)
+
+    def get_context(self, **kwargs):
+        """Simple context to run tests."""
+        st = strax.Context(storage=self.get_mock_sf(), check_available=("records",), **kwargs)
+        return st
+
+    def get_mock_sf(self):
+        mock_rundb = [{"name": "0", strax.RUN_DEFAULTS_KEY: dict(base_area=43)}]
+        sf = strax.DataDirectory(path=self.tempdir, deep_scan=True, provide_run_metadata=True)
+        for d in mock_rundb:
+            sf.write_run_metadata(d["name"], d)
+        return sf
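
Usage sketch (not part of the diff): the snippet below illustrates how a user-defined DownChunkingPlugin, as introduced by this change set, can yield several output chunks from a single compute call. The class and data-type names (SplitRecords, "records_split") are hypothetical; it only relies on the strax calls already used in the diff (self.chunk, strax.endtime, strax.record_dtype).

# Illustrative sketch, modeled on the DownSampleRecords test plugin above.
import numpy as np
import strax


class SplitRecords(strax.DownChunkingPlugin):
    """Yield two output chunks for every input chunk of records (hypothetical example)."""

    depends_on = "records"
    provides = "records_split"
    dtype = strax.record_dtype()
    rechunk_on_save = False

    def compute(self, records, start, end):
        if not len(records):
            # Nothing to split: pass the (empty) interval through as one chunk.
            yield self.chunk(start=start, end=end, data=records)
            return

        middle = max(len(records) // 2, 1)
        first, second = records[:middle], records[middle:]
        # End the first chunk at the latest endtime it contains, so the two
        # output chunks stay contiguous and non-overlapping.
        split_time = np.max(strax.endtime(first))
        yield self.chunk(start=start, end=split_time, data=first)
        yield self.chunk(start=split_time, end=end, data=second)


# Expected usage, mirroring the tests above (single-threaded processing only):
#     st.register(SplitRecords)
#     st.make(run_id, "records_split")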