DM-38041: rewrite pre-exec-init logic to work without QGs and respect storage class differences #444

Merged 16 commits on Sep 12, 2024
3 changes: 3 additions & 0 deletions doc/changes/DM-38041.feature.md
@@ -0,0 +1,3 @@
Add support for initializing processing output runs with just a pipeline graph, not a quantum graph.

This also moves much of the logic for initializing output runs from `lsst.ctrl.mpexec.PreExecInit` to `PipelineGraph` and `QuantumGraph` methods.
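A minimal sketch of the workflow this change enables (the graph file and repository paths below are hypothetical; `init_output_run` and `make_init_qbb` are the new `QuantumGraph` methods added in this PR):

from lsst.pipe.base import QuantumGraph

qg = QuantumGraph.loadUri("pipeline.qgraph")  # hypothetical serialized graph
qbb = qg.make_init_qbb("/repo/butler.yaml")  # hypothetical repo configuration
qg.init_output_run(qbb)  # writes init-outputs, configs, and packages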
8 changes: 4 additions & 4 deletions python/lsst/pipe/base/graph/_versionDeserializers.py
@@ -545,8 +545,8 @@ def constructGraph(
container = {}
datasetDict = _DatasetTracker(createInverse=True)
taskToQuantumNode: defaultdict[TaskDef, set[QuantumNode]] = defaultdict(set)
- initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
- initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
+ initInputRefs: dict[str, list[DatasetRef]] = {}
+ initOutputRefs: dict[str, list[DatasetRef]] = {}

if universe is not None:
if not universe.isCompatibleWith(self.infoMappings.universe):
@@ -597,11 +597,11 @@ def constructGraph(

# initInputRefs and initOutputRefs are optional
if (refs := taskDefDump.get("initInputRefs")) is not None:
- initInputRefs[recreatedTaskDef] = [
+ initInputRefs[recreatedTaskDef.label] = [
cast(DatasetRef, DatasetRef.from_json(ref, universe=universe)) for ref in refs
]
if (refs := taskDefDump.get("initOutputRefs")) is not None:
- initOutputRefs[recreatedTaskDef] = [
+ initOutputRefs[recreatedTaskDef.label] = [
cast(DatasetRef, DatasetRef.from_json(ref, universe=universe)) for ref in refs
]

271 changes: 262 additions & 9 deletions python/lsst/pipe/base/graph/graph.py
@@ -46,22 +46,28 @@

import networkx as nx
from lsst.daf.butler import (
Config,
DatasetId,
DatasetRef,
DatasetType,
DimensionRecordsAccumulator,
DimensionUniverse,
LimitedButler,
Quantum,
QuantumBackedButler,
)
from lsst.daf.butler.datastore.record_data import DatastoreRecordData
from lsst.daf.butler.persistence_context import PersistenceContextVars
from lsst.daf.butler.registry import ConflictingDefinitionError
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils.introspection import get_full_type_name
from lsst.utils.packages import Packages
from networkx.drawing.nx_agraph import write_dot

from ..config import PipelineTaskConfig
from ..connections import iterConnections
from ..pipeline import TaskDef
- from ..pipeline_graph import PipelineGraph
+ from ..pipeline_graph import PipelineGraph, compare_packages, log_config_mismatch
from ._implDetails import DatasetTypeName, _DatasetTracker
from ._loadHelpers import LoadHelper
from ._versionDeserializers import DESERIALIZER_MAP
@@ -286,14 +292,14 @@ def _buildGraphs(
# insertion
self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

- self._initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
- self._initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
+ self._initInputRefs: dict[str, list[DatasetRef]] = {}
+ self._initOutputRefs: dict[str, list[DatasetRef]] = {}
self._globalInitOutputRefs: list[DatasetRef] = []
self._registryDatasetTypes: list[DatasetType] = []
if initInputs is not None:
- self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
+ self._initInputRefs = {taskDef.label: list(refs) for taskDef, refs in initInputs.items()}
if initOutputs is not None:
- self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
+ self._initOutputRefs = {taskDef.label: list(refs) for taskDef, refs in initOutputs.items()}
if globalInitOutputs is not None:
self._globalInitOutputRefs = list(globalInitOutputs)
if registryDatasetTypes is not None:
@@ -812,6 +818,38 @@ def metadata(self) -> MappingProxyType[str, Any]:
"""
return MappingProxyType(self._metadata)

def get_init_input_refs(self, task_label: str) -> list[DatasetRef]:
"""Return the DatasetRefs for the given task's init inputs.

Parameters
----------
task_label : `str`
Label of the task.

Returns
-------
refs : `list` [ `lsst.daf.butler.DatasetRef` ]
Dataset references. Guaranteed to be a new list, not internal
state.
"""
return list(self._initInputRefs.get(task_label, ()))

def get_init_output_refs(self, task_label: str) -> list[DatasetRef]:
"""Return the DatasetRefs for the given task's init outputs.

Parameters
----------
task_label : `str`
Label of the task.

Returns
-------
refs : `list` [ `lsst.daf.butler.DatasetRef` ]
Dataset references. Guaranteed to be a new list, not internal
state.
"""
return list(self._initOutputRefs.get(task_label, ()))
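
# Usage sketch (reviewer example, not part of this diff; assumes a loaded
# QuantumGraph ``qg``): init refs are now keyed by task label, so no TaskDef
# is needed to look them up.
for label in qg.pipeline_graph.tasks:
    input_refs = qg.get_init_input_refs(label)  # a new list, safe to mutate
    output_refs = qg.get_init_output_refs(label)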

def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
"""Return DatasetRefs for a given task InitInputs.

@@ -826,7 +864,7 @@ def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
DatasetRefs for the task's InitInputs; can be `None`. This can
return either resolved or unresolved references.
"""
- return self._initInputRefs.get(taskDef)
+ return self._initInputRefs.get(taskDef.label)

def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
"""Return DatasetRefs for a given task InitOutputs.
@@ -843,7 +881,7 @@ def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
either resolved or unresolved references. A resolved reference will
match the Quantum's initInputs if this is an intermediate dataset type.
"""
- return self._initOutputRefs.get(taskDef)
+ return self._initOutputRefs.get(taskDef.label)

def globalInitOutputRefs(self) -> list[DatasetRef]:
"""Return DatasetRefs for global InitOutputs.
@@ -1027,9 +1065,9 @@ def _buildSaveObjectImpl(self, returnHeader: bool = False) -> bytearray | tuple[
taskDef.config.saveToStream(stream)
taskDescription["config"] = stream.getvalue()
taskDescription["label"] = taskDef.label
- if (refs := self._initInputRefs.get(taskDef)) is not None:
+ if (refs := self._initInputRefs.get(taskDef.label)) is not None:
taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
- if (refs := self._initOutputRefs.get(taskDef)) is not None:
+ if (refs := self._initOutputRefs.get(taskDef.label)) is not None:
taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

inputs = []
@@ -1403,3 +1441,218 @@ def getSummary(self) -> QgraphSummary:
qts.numOutputs[k.name] += 1

return summary

def make_init_qbb(
self,
butler_config: Config | ResourcePathExpression,
*,
config_search_paths: Iterable[str] | None = None,
) -> QuantumBackedButler:
"""Construct an quantum-backed butler suitable for reading and writing
init input and init output datasets, respectively.

This requires the full graph to have been loaded.

Parameters
----------
butler_config : `~lsst.daf.butler.Config` or \
`~lsst.resources.ResourcePathExpression`
A butler repository root, configuration filename, or configuration
instance.
config_search_paths : `~collections.abc.Iterable` [ `str` ], optional
Additional search paths for butler configuration.

Returns
-------
qbb : `~lsst.daf.butler.QuantumBackedButler`
A limited butler that can ``get`` init-input datasets and ``put``
init-output datasets.
"""
universe = self.universe
# Collect all init input/output dataset IDs.
predicted_inputs: set[DatasetId] = set()
predicted_outputs: set[DatasetId] = set()
pipeline_graph = self.pipeline_graph
for task_label in pipeline_graph.tasks:
predicted_inputs.update(ref.id for ref in self.get_init_input_refs(task_label))
predicted_outputs.update(ref.id for ref in self.get_init_output_refs(task_label))
predicted_outputs.update(ref.id for ref in self.globalInitOutputRefs())
# remove intermediates from inputs
predicted_inputs -= predicted_outputs
# Extracting datastore records from the quantum graph is very inefficient:
# we have to scan all quanta and look at their datastore records.
datastore_records: dict[str, DatastoreRecordData] = {}
for quantum_node in self:
for store_name, records in quantum_node.quantum.datastore_records.items():
subset = records.subset(predicted_inputs)
if subset is not None:
datastore_records.setdefault(store_name, DatastoreRecordData()).update(subset)

dataset_types = {dstype.name: dstype for dstype in self.registryDatasetTypes()}
# Make butler from everything.
return QuantumBackedButler.from_predicted(
config=butler_config,
predicted_inputs=predicted_inputs,
predicted_outputs=predicted_outputs,
dimensions=universe,
datastore_records=datastore_records,
search_paths=list(config_search_paths) if config_search_paths is not None else None,
dataset_types=dataset_types,
)
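
# Usage sketch (reviewer example; assumes a full, loaded QuantumGraph ``qg``
# and a hypothetical repository path). The returned butler can only see the
# init datasets predicted above.
qbb = qg.make_init_qbb("/repo/butler.yaml")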

def write_init_outputs(self, butler: LimitedButler, skip_existing: bool = True) -> None:
"""Write the init-output datasets for all tasks in the quantum graph.

Parameters
----------
butler : `lsst.daf.butler.LimitedButler`
A limited butler data repository client.
skip_existing : `bool`, optional
If `True` (default), ignore init-outputs that already exist. If
`False`, raise.

Raises
------
lsst.daf.butler.registry.ConflictingDefinitionError
Raised if an init-output dataset already exists and
``skip_existing=False``.
"""
# Extract init-input and init-output refs from the QG.
input_refs: dict[str, DatasetRef] = {}
output_refs: dict[str, DatasetRef] = {}
for task_node in self.pipeline_graph.tasks.values():
input_refs.update(
{ref.datasetType.name: ref for ref in self.get_init_input_refs(task_node.label)}
)
output_refs.update(
{
ref.datasetType.name: ref
for ref in self.get_init_output_refs(task_node.label)
if ref.datasetType.name != task_node.init.config_output.dataset_type_name
}
)
for ref, is_stored in butler.stored_many(output_refs.values()).items():
if is_stored:
if not skip_existing:
raise ConflictingDefinitionError(f"Init-output dataset {ref} already exists.")
# We'll `put` whatever's left in output_refs at the end.
del output_refs[ref.datasetType.name]
# Instantiate tasks, reading overall init-inputs and gathering
# init-output in-memory objects.
init_outputs: list[tuple[Any, DatasetType]] = []
self.pipeline_graph.instantiate_tasks(
get_init_input=lambda dataset_type: butler.get(
input_refs[dataset_type.name].overrideStorageClass(dataset_type.storageClass)
),
init_outputs=init_outputs,
)
# Write init-outputs that weren't already present.
for obj, dataset_type in init_outputs:
if new_ref := output_refs.get(dataset_type.name):
assert (
new_ref.datasetType.storageClass_name == dataset_type.storageClass_name
), "QG init refs should use task connection storage classes."
butler.put(obj, new_ref)
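
# Usage sketch (assumes ``qg`` and a ``qbb`` from make_init_qbb above).
# With skip_existing=False, any pre-existing init-output raises instead of
# being skipped.
qg.write_init_outputs(qbb, skip_existing=False)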

def write_configs(self, butler: LimitedButler, compare_existing: bool = True) -> None:
"""Write the config datasets for all tasks in the quantum graph.

Parameters
----------
butler : `lsst.daf.butler.LimitedButler`
A limited butler data repository client.
compare_existing : `bool`, optional
If `True`, check configs that already exist for consistency. If
`False`, always raise if a config dataset already exists.

Raises
------
lsst.daf.butler.registry.ConflictingDefinitionError
Raised if a config dataset already exists and
``compare_existing=False``, or if the existing config is not
consistent with the config in the quantum graph.
"""
to_put: list[tuple[PipelineTaskConfig, DatasetRef]] = []
for task_node in self.pipeline_graph.tasks.values():
dataset_type_name = task_node.init.config_output.dataset_type_name
(ref,) = [
ref
for ref in self.get_init_output_refs(task_node.label)
if ref.datasetType.name == dataset_type_name
]
try:
old_config = butler.get(ref)
except (LookupError, FileNotFoundError):
old_config = None
if old_config is not None:
if not compare_existing:
raise ConflictingDefinitionError(f"Config dataset {ref} already exists.")
if not task_node.config.compare(old_config, shortcut=False, output=log_config_mismatch):
raise ConflictingDefinitionError(
f"Config does not match existing task config {dataset_type_name!r} in "
"butler; tasks configurations must be consistent within the same run collection."
)
else:
to_put.append((task_node.config, ref))
# We do writes at the end to minimize the mess we leave behind when we
# raise an exception.
for config, ref in to_put:
butler.put(config, ref)
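
# Usage sketch (assumes ``qg`` and ``qbb`` as above). With the default
# compare_existing=True, existing configs are checked for consistency and
# only the missing ones are written.
qg.write_configs(qbb)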

def write_packages(self, butler: LimitedButler, compare_existing: bool = True) -> None:
"""Write the 'packages' dataset for the currently-active software
versions.

Parameters
----------
butler : `lsst.daf.butler.LimitedButler`
A limited butler data repository client.
compare_existing : `bool`, optional
If `True`, check a packages dataset that already exists for
consistency. If `False`, always raise if the packages dataset
already exists.

Raises
------
lsst.daf.butler.registry.ConflictingDefinitionError
Raised if the packages dataset already exists and
``compare_existing=False``, or if the existing packages are not
consistent with the current packages.
"""
new_packages = Packages.fromSystem()
(ref,) = self.globalInitOutputRefs()
try:
packages = butler.get(ref)
except (LookupError, FileNotFoundError):
packages = None
if packages is not None:
if not compare_existing:
raise ConflictingDefinitionError(f"Packages dataset {ref} already exists.")
if compare_packages(packages, new_packages):
# have to remove existing dataset first; butler has no
# replace option.
butler.pruneDatasets([ref], unstore=True, purge=True)
butler.put(packages, ref)
else:
butler.put(new_packages, ref)
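
# Usage sketch (assumes ``qg`` and ``qbb`` as above). An existing packages
# dataset is checked for consistency and, per the code above, replaced when
# compare_packages reports a difference; inconsistent versions raise.
qg.write_packages(qbb)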

def init_output_run(self, butler: LimitedButler, existing: bool = True) -> None:
"""Initialize a new output RUN collection by writing init-output
datasets (including configs and packages).

Parameters
----------
butler : `lsst.daf.butler.LimitedButler`
A limited butler data repository client.
existing : `bool`, optional
If `True`, check existing outputs for consistency or skip them.
If `False`, always raise if an output dataset already exists.

Raises
------
lsst.daf.butler.registry.ConflictingDefinitionError
Raised if there are existing init output datasets, and either
``existing=False`` or their contents are not compatible with this
graph.
"""
self.write_configs(butler, compare_existing=existing)
self.write_packages(butler, compare_existing=existing)
self.write_init_outputs(butler, skip_existing=existing)
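
# End-to-end sketch (assumes ``qg`` and a hypothetical repo path): one call
# performs all three writes in order (configs, packages, then the remaining
# init-outputs).
qbb = qg.make_init_qbb("/repo/butler.yaml")
qg.init_output_run(qbb, existing=False)  # raise on any pre-existing outputs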