Add function of dependency level of data_types (#896)
* Add function of dependency level of `data_types`

* Cache results of tree_levels

* Minor simplify

* Minor change

* Add `get_dependency_plugins` to get the initialized plugins of all dependencies
dachengx authored Oct 7, 2024
1 parent ef2861e commit ad1c13e
Showing 3 changed files with 94 additions and 22 deletions.
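
For orientation, a minimal sketch of how the additions in this commit could be used, assuming the Records and Peaks test plugins from strax.testutils (the same ones the new test below registers); the printed output is illustrative, not part of the commit:

    import strax
    from strax.testutils import Records, Peaks

    st = strax.Context(register=[Records, Peaks], use_per_run_defaults=True)

    # New: every plugin needed to produce a target, initialized and keyed by data type
    deps = st.get_dependency_plugins("peaks", run_id="0")
    print(list(deps))  # ['records']

    # get_dependencies is now a thin wrapper over get_dependency_plugins
    print(st.get_dependencies("peaks"))  # {'records'}

    # Dependency levels, cached on the context until plugins/versions change
    print(st.tree_levels["peaks"]["level"])  # 1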
102 changes: 82 additions & 20 deletions strax/context.py
@@ -181,6 +181,7 @@ class Context:

runs: ty.Optional[pd.DataFrame] = None
_fixed_plugin_cache: ty.Optional[dict] = None
+    _fixed_level_cache: ty.Optional[dict] = None
_run_defaults_cache: dict
storage: ty.List[strax.StorageFrontend]

@@ -554,7 +555,7 @@ def show_config(self, data_type=None, pattern="*", run_id="9" * 20):
:param data_type: Data type name
:param pattern: Show only options that match (fnmatch) pattern
- :param run_id: Run id to use for run-dependent config options. If omitted, will show
+ :param run_id: run id to use for run-dependent config options. If omitted, will show
defaults active for new runs.
"""
@@ -1924,7 +1925,7 @@ def get_zarr(
cannot fit in memory zarr is very compatible with dask. Targets are loaded into separate
arrays and runs are merged. the data is added to any existing data in the storage location.
- :param run_ids: (Iterable) Run ids you wish to load.
+ :param run_ids: (Iterable) run ids you wish to load.
:param targets: (Iterable) targets to load.
:param storage: (str, optional) fsspec path to store array. Defaults to './strax_temp_data'.
:param overwrite: (boolean, optional) whether to overwrite existing arrays for targets at
@@ -2395,7 +2396,7 @@ def merge_per_chunk_storage(

chunks = self.get_metadata(run_id, per_chunked_dependency)["chunks"]
if chunk_number_group is not None:
- combined_chunk_numbers = list(itertools.chain(*chunk_number_group))
+ combined_chunk_numbers = list(itertools.chain.from_iterable(chunk_number_group))
if len(combined_chunk_numbers) != len(set(combined_chunk_numbers)):
raise ValueError(f"Duplicate chunk numbers found in {chunk_number_group}")
if min(combined_chunk_numbers) == 0 and max(combined_chunk_numbers) == len(chunks) - 1:
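
This swap is behavior-preserving for a list of lists: chain(*x) unpacks the outer iterable into call arguments, while chain.from_iterable(x) consumes it lazily. A quick standalone check:

    import itertools

    chunk_number_group = [[0, 1], [2, 3]]
    assert list(itertools.chain(*chunk_number_group)) == [0, 1, 2, 3]
    assert list(itertools.chain.from_iterable(chunk_number_group)) == [0, 1, 2, 3]

    # from_iterable also accepts a generator without materializing it first
    lazy_groups = ([i, i + 1] for i in range(0, 4, 2))
    assert list(itertools.chain.from_iterable(lazy_groups)) == [0, 1, 2, 3]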
@@ -2534,16 +2535,14 @@ def stored_dependencies(
if target in _targets_stored:
return None

- this_target_is_stored = self.is_stored(run_id, target)
- _targets_stored[target] = this_target_is_stored
+ _targets_stored[target] = self.is_stored(run_id, target)

- if this_target_is_stored:
+ if _targets_stored[target]:
return _targets_stored

# Need to init the class e.g. if we want to allow depends_on which is not a class attribute
plugin = self._plugin_class_registry[target]()
- dependencies = strax.to_str_tuple(plugin.depends_on)
- if not dependencies:
+ if not plugin.depends_on:
raise strax.DataNotAvailable(f"Lowest level dependency {target} is not stored")

forbidden = strax.to_str_tuple(self.context_config["forbid_creation_of"])
@@ -2562,7 +2561,7 @@

self.stored_dependencies(
run_id,
- target=dependencies,
+ target=plugin.depends_on,
check_forbidden=check_forbidden,
_targets_stored=_targets_stored,
)
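
With the temporaries gone, the recursion reads directly off plugin.depends_on. A hedged sketch of calling stored_dependencies (continuing the Records/Peaks context from the first example; the return shape is inferred from the _targets_stored dict built here):

    # Maps each visited data type to whether it is stored, e.g.
    # {'peaks': False, 'records': True}; raises strax.DataNotAvailable when a
    # lowest-level dependency such as 'records' is not stored anywhere.
    stored = st.stored_dependencies("0", target="peaks")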
@@ -2658,21 +2657,33 @@ def provided_dtypes(self, runid="0"):
for data_type, _hash, save_when, version in hashes
}

-    def get_dependencies(self, data_type):
-        """Get the dependencies of a data_type."""
-        dependencies = set()
-
-        def _get_dependencies(_data_type):
-            if _data_type in self.root_data_types:
-                return
-            plugin = self._plugin_class_registry[_data_type]()
-            dependencies.update(plugin.depends_on)
-            for d in plugin.depends_on:
-                _get_dependencies(d)
-
-        _get_dependencies(data_type)
-        return dependencies
+    def get_dependency_plugins(
+        self,
+        target: str,
+        run_id: str,
+        chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None,
+    ) -> ty.Dict[str, strax.Plugin]:
+        """Return all plugins required to produce target.
+
+        :param target: data type to produce
+        :param run_id: run id to use for run-dependent config options
+        :param chunk_number: chunk number to use for run-dependent config options
+        :return: dictionary with data type as key and plugin as value
+
+        """
+        # Get all plugins required to produce target
+        plugins = self._get_plugins((target,), run_id, chunk_number=chunk_number)[target]
+        _dependencies = [plugins.deps.items()]
+        _dependencies += [
+            self.get_dependency_plugins(d, run_id, chunk_number).items() for d in plugins.deps
+        ]
+        dependencies = dict(itertools.chain.from_iterable(_dependencies))
+        return dependencies
+
+    def get_dependencies(self, data_type):
+        """Get the dependencies of a data_type."""
+        return set(self.get_dependency_plugins(data_type, "0").keys())

@property
def root_data_types(self):
"""Root data_type that does not depend on anything."""
@@ -2737,6 +2748,57 @@ def check_support_superrun(data_type, checked=set(), seen_allow=None):
seen_allow = None
checked |= check_support_superrun(data_type, checked, seen_allow)

+    @property
+    def tree_levels(self):
+        """Get the levels of the data types in the context.
+
+        This is useful for telling us which data_type to process first. For example, for a
+        context with Records and Peaks registered, tree_levels will return:
+        {'records': {'level': 0, 'class': 'Records', 'index': 0, 'order': 0},
+         'peaks': {'level': 1, 'class': 'Peaks', 'index': 0, 'order': 1}}
+
+        """
+        context_hash = self._context_hash()
+        if self._fixed_level_cache is not None and context_hash in self._fixed_level_cache:
+            return self._fixed_level_cache[context_hash]
+
+        def _get_levels(data_type=None, results=None):
+            """Get the level of each data_type in the context."""
+            if results is None:
+                results = dict()
+            for k in [data_type] if data_type else self._plugin_class_registry.keys():
+                results[k] = dict()
+                _v = self._plugin_class_registry[k]()
+                if _v.depends_on:
+                    results[k]["level"] = (
+                        max(_get_levels(d, results)[d]["level"] for d in _v.depends_on) + 1
+                    )
+                else:
+                    results[k]["level"] = 0
+                results[k]["class"] = self._plugin_class_registry[k].__name__
+                results[k]["index"] = _v.provides.index(k)
+            return results
+
+        # Sort the results by level, class, and index in provides
+        _results = sorted(
+            _get_levels().items(), key=lambda x: (x[1]["level"], x[1]["class"], x[1]["index"])
+        )
+
+        # Assign order to the results
+        for order, (key, value) in enumerate(_results):
+            value["order"] = order
+        results = dict(_results)
+
+        if self._fixed_level_cache is None:
+            self._fixed_level_cache = {context_hash: results}
+        elif context_hash not in self._fixed_level_cache:
+            self.log.info("Replacing context._fixed_level_cache since plugins/versions changed")
+            self._fixed_level_cache = {context_hash: results}
+
+        return results
@classmethod
def add_method(cls, f):
"""Add f as a new Context method."""
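
Continuing the sketch from the top: the mapping returned by tree_levels makes processing order explicit, and the order field gives a stable total order (sorted by level, class name, and index in provides):

    levels = st.tree_levels
    # {'records': {'level': 0, 'class': 'Records', 'index': 0, 'order': 0},
    #  'peaks': {'level': 1, 'class': 'Peaks', 'index': 0, 'order': 1}}

    # Lowest dependency level first; cached per context hash, so repeated access
    # is cheap until registered plugins or their versions change
    processing_order = sorted(levels, key=lambda dt: levels[dt]["order"])
    assert processing_order == ["records", "peaks"]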
12 changes: 12 additions & 0 deletions tests/test_core.py
@@ -510,3 +510,15 @@ def test_per_chunk_storage():
st.register(p)
with pytest.raises(ValueError):
st.make(run_id, "whatever", chunk_number={"records": [0]})


+def test_dependency_tree():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        st = strax.Context(
+            storage=strax.DataDirectory(temp_dir, deep_scan=True),
+            register=[Records, Peaks],
+            use_per_run_defaults=True,
+        )
+        st.tree
+        st.inversed_tree
+        st.tree_levels
2 changes: 0 additions & 2 deletions tests/test_superruns.py
@@ -337,8 +337,6 @@ def test_only_combining_superruns(self):
The test also shows the difference between the two.
"""
- self.context.tree
- self.context.inversed_tree
self.context.check_superrun()
sum_super = self.context.get_array(self.superrun_name, "sum", save="sum")
assert self.context.is_stored(self.superrun_name, "sum")
