Commit 0ab760e

Merge branch 'add_warning' of github.com:AxFoundation/strax into add_warning

WenzDaniel committed Dec 14, 2023
2 parents: b550058 + 2b45413

Showing 8 changed files with 317 additions and 122 deletions.
210 changes: 116 additions & 94 deletions strax/context.py
@@ -683,7 +683,11 @@ def _fix_dependency(self, plugin_registry: dict, end_plugin: str):
self._fix_dependency(plugin_registry, go_to)
plugin_registry[end_plugin].fix_dtype()

def __get_plugins_from_cache(self, run_id: str) -> ty.Dict[str, strax.Plugin]:
def __get_requested_plugins_from_cache(
self,
run_id: str,
targets: ty.Tuple[str],
) -> ty.Dict[str, strax.Plugin]:
# Doubly underscored since we don't do any key-checks etc here
"""Load requested plugins from the plugin_cache."""
requested_plugins = {}
@@ -709,37 +713,27 @@ def __get_plugins_from_cache(self, run_id: str) -> ty.Dict[str, strax.Plugin]:
plugin.deps = {
dependency: requested_plugins[dependency] for dependency in plugin.depends_on
}

# Finally, fix the dtype. Since infer_dtype may depend on the
# entire deps chain, we need to start at the last plugin and go
# all the way down to the lowest level.
for final_plugins in self._get_end_targets(requested_plugins):
self._fix_dependency(requested_plugins, final_plugins)
for target_plugins in targets:
self._fix_dependency(requested_plugins, target_plugins)

requested_plugins = {i: v for i, v in requested_plugins.items() if i in targets}
return requested_plugins

def _get_plugins(
self, targets: ty.Union[ty.Tuple[str], ty.List[str]], run_id: str
self,
targets: ty.Union[ty.Tuple[str], ty.List[str]],
run_id: str,
) -> ty.Dict[str, strax.Plugin]:
"""Return dictionary of plugin instances necessary to compute targets from scratch.
For a plugin that produces multiple outputs, we make only a single instance, which is
referenced under multiple keys in the output dict.
"""
if self._plugins_are_cached(targets):
cached_plugins = self.__get_plugins_from_cache(run_id)
plugins = {}
targets = list(targets)
while targets:
target = targets.pop(0)
if target in plugins:
continue

target_plugin = cached_plugins[target]
for provides in target_plugin.provides:
plugins[provides] = target_plugin
targets += list(target_plugin.depends_on)
return plugins

# Check all config options are taken by some registered plugin class
# (helps spot typos)
all_opts = set().union(*[
@@ -749,98 +743,126 @@ def _get_plugins(
if not (k in all_opts or k in self.context_config["free_options"]):
self.log.warning(f"Option {k} not taken by any registered plugin")

# Initialize plugins for the entire computation graph
# (most likely far further down than we need)
# to get lineages and dependency info.
def get_plugin(data_type):
nonlocal non_local_plugins
plugins = {}
targets = list(targets)
safety_counter = 0
while targets and safety_counter < 10_000:
safety_counter += 1
targets = list(set(targets)) # Remove duplicates from list.
target = targets.pop(0)
if target in plugins:
continue

if data_type not in self._plugin_class_registry:
raise KeyError(f"No plugin class registered that provides {data_type}")
target_plugin = self.__get_plugin(run_id, target)
for provides in target_plugin.provides:
plugins[provides] = target_plugin
targets += list(target_plugin.depends_on)

plugin = self._plugin_class_registry[data_type]()
_not_all_plugins_initialized = (safety_counter == 10_000) and len(targets)
if _not_all_plugins_initialized:
raise ValueError(
"Could not initialize all plugins to compute the targets from scratch. "
f"The remaining targets are: {targets}"
)

d_provides = None # just to make codefactor happy
for d_provides in plugin.provides:
non_local_plugins[d_provides] = plugin
return plugins

plugin.run_id = run_id
def __get_plugin(self, run_id: str, data_type: str):
"""Get single plugin either from cache or initialize it."""
# Check if plugin for data_type is already cached
if self._plugins_are_cached((data_type,)):
cached_plugins = self.__get_requested_plugins_from_cache(run_id, (data_type,))
target_plugin = cached_plugins[data_type]
return target_plugin

# The plugin may not get all the required options here
# but we don't know if we need the plugin yet
self._set_plugin_config(plugin, run_id, tolerant=True)
if data_type not in self._plugin_class_registry:
raise KeyError(f"No plugin class registered that provides {data_type}")

plugin.deps = {d_depends: get_plugin(d_depends) for d_depends in plugin.depends_on}
plugin = self._plugin_class_registry[data_type]()

last_provide = d_provides
plugin.run_id = run_id

if plugin.child_plugin:
# Plugin is a child of another plugin, hence we have to
# drop the parent's config from the lineage
configs = {}
# The plugin may not get all the required options here
# but we don't know if we need the plugin yet
self._set_plugin_config(plugin, run_id, tolerant=True)

# Getting information about the parent:
parent_class = plugin.__class__.__bases__[0]
# Get all parent options which are overwritten by a child:
parent_options = [
option.parent_option_name
for option in plugin.takes_config.values()
if option.child_option
]
plugin.deps = {
d_depends: self.__get_plugin(run_id, d_depends) for d_depends in plugin.depends_on
}

for option_name, v in plugin.config.items():
# Looping over all settings, option_name is either the option name of the
# parent or the child.
if option_name in parent_options:
# In case it is the parent we continue
continue
self.__add_lineage_to_plugin(run_id, plugin)

if plugin.takes_config[option_name].track:
# Add all options which should be tracked:
configs[option_name] = v
if not hasattr(plugin, "data_kind") and not plugin.multi_output:
if len(plugin.depends_on):
# Assume data kind is the same as the first dependency
first_dep = plugin.depends_on[0]
plugin.data_kind = plugin.deps[first_dep].data_kind_for(first_dep)
else:
# No dependencies: assume provided data kind and
# data type are synonymous
plugin.data_kind = plugin.provides[0]

# Also adding name and version of the parent to the lineage:
configs[parent_class.__name__] = parent_class.__version__
plugin.fix_dtype()

plugin.lineage = {
last_provide: (plugin.__class__.__name__, plugin.version(run_id), configs)
}
else:
plugin.lineage = {
last_provide: (
plugin.__class__.__name__,
plugin.version(run_id),
{
option: setting
for option, setting in plugin.config.items()
if plugin.takes_config[option].track
},
)
}
for d_depends in plugin.depends_on:
plugin.lineage.update(plugin.deps[d_depends].lineage)

if not hasattr(plugin, "data_kind") and not plugin.multi_output:
if len(plugin.depends_on):
# Assume data kind is the same as the first dependency
first_dep = plugin.depends_on[0]
plugin.data_kind = plugin.deps[first_dep].data_kind_for(first_dep)
else:
# No dependencies: assume provided data kind and
# data type are synonymous
plugin.data_kind = plugin.provides[0]
# Add plugin to cache
self._plugins_to_cache({data_type: plugin for data_type in plugin.provides})

plugin.fix_dtype()
return plugin

return plugin
def __add_lineage_to_plugin(self, run_id, plugin):
"""Adds lineage to plugin in place.
non_local_plugins = {}
for t in targets:
p = get_plugin(t)
non_local_plugins[t] = p
Also adds parent information in case of a child plugin.
"""
last_provide = plugin.provides[-1]

if plugin.child_plugin:
# Plugin is a child of another plugin, hence we have to
# drop the parent's config from the lineage
configs = {}

# Getting information about the parent:
parent_class = plugin.__class__.__bases__[0]
# Get all parent options which are overwritten by a child:
parent_options = [
option.parent_option_name
for option in plugin.takes_config.values()
if option.child_option
]

for option_name, v in plugin.config.items():
# Looping over all settings, option_name is either the option name of the
# parent or the child.
if option_name in parent_options:
# In case it is the parent we continue
continue

if plugin.takes_config[option_name].track:
# Add all options which should be tracked:
configs[option_name] = v

# Also adding name and version of the parent to the lineage:
configs[parent_class.__name__] = parent_class.__version__

plugin.lineage = {
last_provide: (plugin.__class__.__name__, plugin.version(run_id), configs)
}
else:
plugin.lineage = {
last_provide: (
plugin.__class__.__name__,
plugin.version(run_id),
{
option: setting
for option, setting in plugin.config.items()
if plugin.takes_config[option].track
},
)
}

self._plugins_to_cache(non_local_plugins)
return non_local_plugins
for d_depends in plugin.depends_on:
plugin.lineage.update(plugin.deps[d_depends].lineage)
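
The resulting lineage maps each provided data type to a tuple of (plugin class name, version, tracked options), merged with the lineage of every dependency. A hypothetical example of the shape this produces (names, versions and options invented for illustration):

    plugin.lineage = {
        "peak_basics": ("PeakBasics", "0.1.0", {"min_area": 10}),
        "peaks": ("Peaks", "1.0.0", {"gap_threshold": 350}),
        "records": ("Records", "0.0.1", {}),
    }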

def _per_run_default_allowed_check(self, option_name, option):
"""Check if an option of a registered plugin is allowed."""
1 change: 1 addition & 0 deletions strax/plugins/__init__.py
@@ -4,3 +4,4 @@
from .merge_only_plugin import *
from .overlap_window_plugin import *
from .parrallel_source_plugin import *
from .down_chunking_plugin import *
44 changes: 44 additions & 0 deletions strax/plugins/down_chunking_plugin.py
@@ -0,0 +1,44 @@
import strax
from .plugin import Plugin

export, __all__ = strax.exporter()


##
# Plugin which allows the use of yield in a plugin's compute method.
# This makes it possible to chunk down the output before storing it to disk.
# Only works if multiprocessing is omitted.
##


@export
class DownChunkingPlugin(Plugin):
"""Plugin that merges data from its dependencies."""

parallel = False

def __init__(self):
super().__init__()

if self.parallel:
raise NotImplementedError(
f'Plugin "{self.__class__.__name__}" is a DownChunkingPlugin which '
"currently does not support parallel processing."
)

if self.multi_output:
raise NotImplementedError(
f'Plugin "{self.__class__.__name__}" is a DownChunkingPlugin which '
"currently does not support multiple outputs. Please only provide "
"a single data-type."
)

def iter(self, iters, executor=None):
return super().iter(iters, executor)

def _iter_compute(self, chunk_i, **inputs_merged):
return self.do_compute(chunk_i=chunk_i, **inputs_merged)

def _fix_output(self, result, start, end, _dtype=None):
"""Wrapper around _fix_output to support the return of iterators."""
return result
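
With this in place, a subclass overrides compute as a generator and yields several small chunks per input chunk instead of returning one array. A hedged usage sketch (plugin name, dtype choice and the splitting logic are illustrative only; a real plugin must keep every record inside the chunk boundaries it emits):

    import numpy as np
    import strax

    class SplitRecords(strax.DownChunkingPlugin):
        """Emit each input chunk as two half-range output chunks."""

        __version__ = "0.0.1"
        depends_on = ("records",)
        provides = "split_records"
        dtype = strax.time_fields

        def compute(self, records, start, end):
            mid = start + (end - start) // 2
            for lo, hi in ((start, mid), (mid, end)):
                sel = records[(records["time"] >= lo) & (records["time"] < hi)]
                out = np.zeros(len(sel), dtype=self.dtype)
                out["time"] = sel["time"]
                # Clip so the emitted data stays within the sub-chunk's bounds
                out["endtime"] = np.clip(strax.endtime(sel), lo, hi)
                yield self.chunk(start=lo, end=hi, data=out)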
8 changes: 7 additions & 1 deletion strax/plugins/plugin.py
@@ -133,6 +133,8 @@ def __init__(self):
# not have to update save_when
self.save_when = immutabledict.fromkeys(self.provides, self.save_when)

if getattr(self, "provides", None):
self.provides = strax.to_str_tuple(self.provides)
self.compute_pars = compute_pars
self.input_buffer = dict()
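
Normalizing provides through strax.to_str_tuple lets the rest of the class always iterate over a tuple, whether the subclass declared a single string or a list. Roughly (a small illustration, assuming strax is importable):

    import strax

    assert strax.to_str_tuple("peaks") == ("peaks",)
    assert strax.to_str_tuple(["peaks", "peak_basics"]) == ("peaks", "peak_basics")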

@@ -492,7 +494,7 @@ class IterDone(Exception):
pending_futures = [f for f in pending_futures if not f.done()]
yield new_future
else:
yield self.do_compute(chunk_i=chunk_i, **inputs_merged)
yield from self._iter_compute(chunk_i=chunk_i, **inputs_merged)

except IterDone:
# Check all sources are exhausted.
@@ -517,6 +519,10 @@ class IterDone(Exception):
finally:
self.cleanup(wait_for=pending_futures)

def _iter_compute(self, chunk_i, **inputs_merged):
"""Either yields or returns strax chunks from the input."""
yield self.do_compute(chunk_i=chunk_i, **inputs_merged)

def cleanup(self, wait_for):
pass
# A standard plugin doesn't need to do anything here
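
The switch from "yield self.do_compute(...)" to "yield from self._iter_compute(...)" works because the base class wraps its single result in a one-element generator, while DownChunkingPlugin._iter_compute hands back the generator that a yielding compute method produces. The dispatch in miniature (plain functions standing in for the two methods):

    def base_iter_compute(result):
        # Plugin: exactly one output chunk per input chunk
        yield result

    def down_chunking_iter_compute(results):
        # DownChunkingPlugin: compute is itself a generator; pass it through
        return results

    print(list(base_iter_compute("chunk_0")))                  # ['chunk_0']
    print(list(down_chunking_iter_compute(iter(["a", "b"]))))  # ['a', 'b']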
43 changes: 19 additions & 24 deletions strax/processing/general.py
@@ -1,4 +1,3 @@
import os
import warnings

warnings.simplefilter("always", UserWarning)
@@ -58,29 +57,25 @@ def _sort_by_time_and_channel(x, channel, max_channel_plus_one, sort_kind="merge
return x[sort_i]


# Getting endtime jitted is a bit awkward, especially since it has to
# keep working with NUMBA_DISABLE_JIT, which we use for coverage tests.
# See https://github.com/numba/numba/issues/4759
if os.environ.get("NUMBA_DISABLE_JIT"):

@export
def endtime(x):
"""Return endtime of intervals x."""
if "endtime" in x.dtype.fields:
return x["endtime"]
else:
return x["time"] + x["length"] * x["dt"]

else:

@export
@numba.generated_jit(nopython=True, nogil=True)
def endtime(x):
"""Return endtime of intervals x."""
if "endtime" in x.dtype.fields:
return lambda x: x["endtime"]
else:
return lambda x: x["time"] + x["length"] * x["dt"]
@export
def endtime(x):
"""Return endtime of intervals x."""
if "endtime" in x.dtype.fields:
return x["endtime"]
else:
return x["time"] + x["length"] * x["dt"]


# Jitting endtime needs special attention, since inspecting the dtype
# has to happen in the Python layer.
# (This used to work through numba.generated_jit; it now uses numba.extending.overload.)
@numba.extending.overload(endtime)
def _overload_endtime(x):
"""Return endtime of intervals x."""
if "endtime" in x.dtype.fields:
return lambda x: x["endtime"]
else:
return lambda x: x["time"] + x["length"] * x["dt"]
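
Under the new scheme the plain-Python definition runs as-is when jitting is disabled, while numba.extending.overload registers a compile-time implementation chosen by inspecting the dtype, so endtime keeps working inside nopython functions. A small sketch of calling it from jitted code (assuming strax is installed; endtime is exported as strax.endtime):

    import numba
    import numpy as np
    from strax import endtime

    # A toy interval dtype without an explicit endtime field
    interval_dtype = np.dtype([("time", np.int64), ("length", np.int32), ("dt", np.int32)])

    @numba.njit
    def last_endtime(x):
        # Resolved at compile time via the @numba.extending.overload above
        return endtime(x)[-1]

    x = np.zeros(2, dtype=interval_dtype)
    x["time"] = [0, 10]
    x["length"] = 5
    x["dt"] = 2
    print(endtime(x))       # Python path -> [10 20]
    print(last_endtime(x))  # Jitted path -> 20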


@export