diff --git a/doc/changes/DM-38041.feature.md b/doc/changes/DM-38041.feature.md
new file mode 100644
index 000000000..e2ef1ae92
--- /dev/null
+++ b/doc/changes/DM-38041.feature.md
@@ -0,0 +1,3 @@
+Add support for initializing processing output runs with just a pipeline graph, not a quantum graph.
+
+This also moves much of the logic for initializing output runs from `lsst.ctrl.mpexec.PreExecInit` to `PipelineGraph` and `QuantumGraph` methods.
diff --git a/python/lsst/pipe/base/graph/_versionDeserializers.py b/python/lsst/pipe/base/graph/_versionDeserializers.py
index 0bcf020f5..908e5d681 100644
--- a/python/lsst/pipe/base/graph/_versionDeserializers.py
+++ b/python/lsst/pipe/base/graph/_versionDeserializers.py
@@ -545,8 +545,8 @@ def constructGraph(
container = {}
datasetDict = _DatasetTracker(createInverse=True)
taskToQuantumNode: defaultdict[TaskDef, set[QuantumNode]] = defaultdict(set)
- initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
- initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
+ initInputRefs: dict[str, list[DatasetRef]] = {}
+ initOutputRefs: dict[str, list[DatasetRef]] = {}
if universe is not None:
if not universe.isCompatibleWith(self.infoMappings.universe):
@@ -597,11 +597,11 @@ def constructGraph(
# initInputRefs and initOutputRefs are optional
if (refs := taskDefDump.get("initInputRefs")) is not None:
- initInputRefs[recreatedTaskDef] = [
+ initInputRefs[recreatedTaskDef.label] = [
cast(DatasetRef, DatasetRef.from_json(ref, universe=universe)) for ref in refs
]
if (refs := taskDefDump.get("initOutputRefs")) is not None:
- initOutputRefs[recreatedTaskDef] = [
+ initOutputRefs[recreatedTaskDef.label] = [
cast(DatasetRef, DatasetRef.from_json(ref, universe=universe)) for ref in refs
]
diff --git a/python/lsst/pipe/base/graph/graph.py b/python/lsst/pipe/base/graph/graph.py
index 3b41989b2..0cce2b3d7 100644
--- a/python/lsst/pipe/base/graph/graph.py
+++ b/python/lsst/pipe/base/graph/graph.py
@@ -46,22 +46,28 @@
import networkx as nx
from lsst.daf.butler import (
+ Config,
DatasetId,
DatasetRef,
DatasetType,
DimensionRecordsAccumulator,
DimensionUniverse,
+ LimitedButler,
Quantum,
+ QuantumBackedButler,
)
+from lsst.daf.butler.datastore.record_data import DatastoreRecordData
from lsst.daf.butler.persistence_context import PersistenceContextVars
+from lsst.daf.butler.registry import ConflictingDefinitionError
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils.introspection import get_full_type_name
from lsst.utils.packages import Packages
from networkx.drawing.nx_agraph import write_dot
+from ..config import PipelineTaskConfig
from ..connections import iterConnections
from ..pipeline import TaskDef
-from ..pipeline_graph import PipelineGraph
+from ..pipeline_graph import PipelineGraph, compare_packages, log_config_mismatch
from ._implDetails import DatasetTypeName, _DatasetTracker
from ._loadHelpers import LoadHelper
from ._versionDeserializers import DESERIALIZER_MAP
@@ -286,14 +292,14 @@ def _buildGraphs(
# insertion
self._taskToQuantumNode = dict(self._taskToQuantumNode.items())
- self._initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
- self._initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
+ self._initInputRefs: dict[str, list[DatasetRef]] = {}
+ self._initOutputRefs: dict[str, list[DatasetRef]] = {}
self._globalInitOutputRefs: list[DatasetRef] = []
self._registryDatasetTypes: list[DatasetType] = []
if initInputs is not None:
- self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
+ self._initInputRefs = {taskDef.label: list(refs) for taskDef, refs in initInputs.items()}
if initOutputs is not None:
- self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
+ self._initOutputRefs = {taskDef.label: list(refs) for taskDef, refs in initOutputs.items()}
if globalInitOutputs is not None:
self._globalInitOutputRefs = list(globalInitOutputs)
if registryDatasetTypes is not None:
@@ -812,6 +818,38 @@ def metadata(self) -> MappingProxyType[str, Any]:
"""
return MappingProxyType(self._metadata)
+ def get_init_input_refs(self, task_label: str) -> list[DatasetRef]:
+ """Return the DatasetRefs for the given task's init inputs.
+
+ Parameters
+ ----------
+ task_label : `str`
+ Label of the task.
+
+ Returns
+ -------
+ refs : `list` [ `lsst.daf.butler.DatasetRef` ]
+ Dataset references. Guaranteed to be a new list, not internal
+ state.
+ """
+ return list(self._initInputRefs.get(task_label, ()))
+
+ def get_init_output_refs(self, task_label: str) -> list[DatasetRef]:
+ """Return the DatasetRefs for the given task's init outputs.
+
+ Parameters
+ ----------
+ task_label : `str`
+ Label of the task.
+
+ Returns
+ -------
+ refs : `list` [ `lsst.daf.butler.DatasetRef` ]
+ Dataset references. Guaranteed to be a new list, not internal
+ state.
+ """
+ return list(self._initOutputRefs.get(task_label, ()))
+
def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
"""Return DatasetRefs for a given task InitInputs.
@@ -826,7 +864,7 @@ def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
DatasetRef for the task InitInput, can be `None`. This can return
either resolved or non-resolved reference.
"""
- return self._initInputRefs.get(taskDef)
+ return self._initInputRefs.get(taskDef.label)
def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
"""Return DatasetRefs for a given task InitOutputs.
@@ -843,7 +881,7 @@ def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
either resolved or non-resolved reference. Resolved reference will
match Quantum's initInputs if this is an intermediate dataset type.
"""
- return self._initOutputRefs.get(taskDef)
+ return self._initOutputRefs.get(taskDef.label)
def globalInitOutputRefs(self) -> list[DatasetRef]:
"""Return DatasetRefs for global InitOutputs.
@@ -1027,9 +1065,9 @@ def _buildSaveObjectImpl(self, returnHeader: bool = False) -> bytearray | tuple[
taskDef.config.saveToStream(stream)
taskDescription["config"] = stream.getvalue()
taskDescription["label"] = taskDef.label
- if (refs := self._initInputRefs.get(taskDef)) is not None:
+ if (refs := self._initInputRefs.get(taskDef.label)) is not None:
taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
- if (refs := self._initOutputRefs.get(taskDef)) is not None:
+ if (refs := self._initOutputRefs.get(taskDef.label)) is not None:
taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]
inputs = []
@@ -1403,3 +1441,218 @@ def getSummary(self) -> QgraphSummary:
qts.numOutputs[k.name] += 1
return summary
+
+ def make_init_qbb(
+ self,
+ butler_config: Config | ResourcePathExpression,
+ *,
+ config_search_paths: Iterable[str] | None = None,
+ ) -> QuantumBackedButler:
+ """Construct a quantum-backed butler suitable for reading init-input
+ datasets and writing init-output datasets.
+
+ This requires the full graph to have been loaded.
+
+ Parameters
+ ----------
+ butler_config : `~lsst.daf.butler.Config` or \
+ `~lsst.resources.ResourcePathExpression`
+ A butler repository root, configuration filename, or configuration
+ instance.
+ config_search_paths : `~collections.abc.Iterable` [ `str` ], optional
+ Additional search paths for butler configuration.
+
+ Returns
+ -------
+ qbb : `~lsst.daf.butler.QuantumBackedButler`
+ A limited butler that can ``get`` init-input datasets and ``put``
+ init-output datasets.
+ """
+ universe = self.universe
+ # Collect all init input/output dataset IDs.
+ predicted_inputs: set[DatasetId] = set()
+ predicted_outputs: set[DatasetId] = set()
+ pipeline_graph = self.pipeline_graph
+ for task_label in pipeline_graph.tasks:
+ predicted_inputs.update(ref.id for ref in self.get_init_input_refs(task_label))
+ predicted_outputs.update(ref.id for ref in self.get_init_output_refs(task_label))
+ # Global init-outputs are not attached to any task, so add them separately.
+ predicted_outputs.update(ref.id for ref in self.globalInitOutputRefs())
+ # Remove intermediates from inputs: an ID that also appears as an
+ # init-output of this graph is not an overall input.
+ predicted_inputs -= predicted_outputs
+ # Very inefficient way to extract datastore records from quantum graph,
+ # we have to scan all quanta and look at their datastore records.
+ datastore_records: dict[str, DatastoreRecordData] = {}
+ for quantum_node in self:
+ for store_name, records in quantum_node.quantum.datastore_records.items():
+ subset = records.subset(predicted_inputs)
+ # NOTE(review): presumably `subset` returns None when no records
+ # match the requested IDs - confirm against DatastoreRecordData.
+ if subset is not None:
+ datastore_records.setdefault(store_name, DatastoreRecordData()).update(subset)
+
+ # Dataset type definitions keyed by name, as recorded in this graph.
+ dataset_types = {dstype.name: dstype for dstype in self.registryDatasetTypes()}
+ # Make butler from everything.
+ return QuantumBackedButler.from_predicted(
+ config=butler_config,
+ predicted_inputs=predicted_inputs,
+ predicted_outputs=predicted_outputs,
+ dimensions=universe,
+ datastore_records=datastore_records,
+ search_paths=list(config_search_paths) if config_search_paths is not None else None,
+ dataset_types=dataset_types,
+ )
+
+ def write_init_outputs(self, butler: LimitedButler, skip_existing: bool = True) -> None:
+ """Write the init-output datasets for all tasks in the quantum graph.
+
+ Task config datasets are deliberately excluded here; they are written
+ by `write_configs`.
+
+ Parameters
+ ----------
+ butler : `lsst.daf.butler.LimitedButler`
+ A limited butler data repository client.
+ skip_existing : `bool`, optional
+ If `True` (default) ignore init-outputs that already exist. If
+ `False`, raise.
+
+ Raises
+ ------
+ lsst.daf.butler.registry.ConflictingDefinitionError
+ Raised if an init-output dataset already exists and
+ ``skip_existing=False``.
+ """
+ # Extract init-input and init-output refs from the QG.
+ # Both mappings are keyed by dataset type name.
+ input_refs: dict[str, DatasetRef] = {}
+ output_refs: dict[str, DatasetRef] = {}
+ for task_node in self.pipeline_graph.tasks.values():
+ input_refs.update(
+ {ref.datasetType.name: ref for ref in self.get_init_input_refs(task_node.label)}
+ )
+ output_refs.update(
+ {
+ ref.datasetType.name: ref
+ for ref in self.get_init_output_refs(task_node.label)
+ if ref.datasetType.name != task_node.init.config_output.dataset_type_name
+ }
+ )
+ # The existence check runs before any task is instantiated, so a
+ # conflict is reported before init-inputs are read.
+ for ref, is_stored in butler.stored_many(output_refs.values()).items():
+ if is_stored:
+ if not skip_existing:
+ raise ConflictingDefinitionError(f"Init-output dataset {ref} already exists.")
+ # We'll `put` whatever's left in output_refs at the end.
+ del output_refs[ref.datasetType.name]
+ # Instantiate tasks, reading overall init-inputs and gathering
+ # init-output in-memory objects.
+ init_outputs: list[tuple[Any, DatasetType]] = []
+ self.pipeline_graph.instantiate_tasks(
+ get_init_input=lambda dataset_type: butler.get(
+ input_refs[dataset_type.name].overrideStorageClass(dataset_type.storageClass)
+ ),
+ init_outputs=init_outputs,
+ )
+ # Write init-outputs that weren't already present.
+ # Objects with no entry left in output_refs (configs, or datasets that
+ # were already stored) are skipped.
+ for obj, dataset_type in init_outputs:
+ if new_ref := output_refs.get(dataset_type.name):
+ assert (
+ new_ref.datasetType.storageClass_name == dataset_type.storageClass_name
+ ), "QG init refs should use task connection storage classes."
+ butler.put(obj, new_ref)
+
+ def write_configs(self, butler: LimitedButler, compare_existing: bool = True) -> None:
+ """Write the config datasets for all tasks in the quantum graph.
+
+ Parameters
+ ----------
+ butler : `lsst.daf.butler.LimitedButler`
+ A limited butler data repository client.
+ compare_existing : `bool`, optional
+ If `True` check configs that already exist for consistency. If
+ `False`, always raise if configs already exist.
+
+ Raises
+ ------
+ lsst.daf.butler.registry.ConflictingDefinitionError
+ Raised if a config dataset already exists and
+ ``compare_existing=False``, or if the existing config is not
+ consistent with the config in the quantum graph.
+ """
+ to_put: list[tuple[PipelineTaskConfig, DatasetRef]] = []
+ for task_node in self.pipeline_graph.tasks.values():
+ dataset_type_name = task_node.init.config_output.dataset_type_name
+ # Exactly one init-output ref must match the config dataset type;
+ # the single-element unpacking raises ValueError otherwise.
+ (ref,) = [
+ ref
+ for ref in self.get_init_output_refs(task_node.label)
+ if ref.datasetType.name == dataset_type_name
+ ]
+ # A failed lookup is treated as "no config written yet".
+ try:
+ old_config = butler.get(ref)
+ except (LookupError, FileNotFoundError):
+ old_config = None
+ if old_config is not None:
+ if not compare_existing:
+ raise ConflictingDefinitionError(f"Config dataset {ref} already exists.")
+ if not task_node.config.compare(old_config, shortcut=False, output=log_config_mismatch):
+ raise ConflictingDefinitionError(
+ f"Config does not match existing task config {dataset_type_name!r} in "
+ "butler; tasks configurations must be consistent within the same run collection."
+ )
+ else:
+ to_put.append((task_node.config, ref))
+ # We do writes at the end to minimize the mess we leave behind when we
+ # raise an exception.
+ for config, ref in to_put:
+ butler.put(config, ref)
+
+ def write_packages(self, butler: LimitedButler, compare_existing: bool = True) -> None:
+ """Write the 'packages' dataset for the currently-active software
+ versions.
+
+ Parameters
+ ----------
+ butler : `lsst.daf.butler.LimitedButler`
+ A limited butler data repository client.
+ compare_existing : `bool`, optional
+ If `True` check packages that already exist for consistency. If
+ `False`, always raise if the packages dataset already exists.
+
+ Raises
+ ------
+ lsst.daf.butler.registry.ConflictingDefinitionError
+ Raised if the packages dataset already exists and either
+ ``compare_existing=False`` or it is not consistent with the current
+ packages.
+ """
+ new_packages = Packages.fromSystem()
+ # The graph is expected to hold exactly one global init-output ref;
+ # the single-element unpacking raises ValueError otherwise.
+ (ref,) = self.globalInitOutputRefs()
+ # A failed lookup is treated as "no packages dataset written yet".
+ try:
+ packages = butler.get(ref)
+ except (LookupError, FileNotFoundError):
+ packages = None
+ if packages is not None:
+ if not compare_existing:
+ raise ConflictingDefinitionError(f"Packages dataset {ref} already exists.")
+ # compare_packages raises on a version conflict and returns True only
+ # when it merged new entries into ``packages``.
+ if compare_packages(packages, new_packages):
+ # have to remove existing dataset first; butler has no
+ # replace option.
+ butler.pruneDatasets([ref], unstore=True, purge=True)
+ butler.put(packages, ref)
+ else:
+ butler.put(new_packages, ref)
+
+ def init_output_run(self, butler: LimitedButler, existing: bool = True) -> None:
+ """Initialize a new output RUN collection by writing init-output
+ datasets (including configs and packages).
+
+ Parameters
+ ----------
+ butler : `lsst.daf.butler.LimitedButler`
+ A limited butler data repository client.
+ existing : `bool`, optional
+ If `True` check or ignore outputs that already exist. If
+ `False`, always raise if an output dataset already exists.
+
+ Raises
+ ------
+ lsst.daf.butler.registry.ConflictingDefinitionError
+ Raised if there are existing init output datasets, and either
+ ``existing=False`` or their contents are not compatible with this
+ graph.
+ """
+ # ``existing`` is forwarded as ``compare_existing`` to the config and
+ # packages writers and as ``skip_existing`` to the init-output writer.
+ self.write_configs(butler, compare_existing=existing)
+ self.write_packages(butler, compare_existing=existing)
+ self.write_init_outputs(butler, skip_existing=existing)
diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
index f2e7ddf48..e17bdd775 100644
--- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
+++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
@@ -26,21 +26,34 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
-__all__ = ("PipelineGraph",)
+__all__ = ("PipelineGraph", "log_config_mismatch", "compare_packages")
import gzip
import itertools
import json
+import logging
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, TypeVar, cast
import networkx
import networkx.algorithms.bipartite
import networkx.algorithms.dag
-from lsst.daf.butler import DataCoordinate, DataId, DatasetType, DimensionGroup, DimensionUniverse, Registry
-from lsst.daf.butler.registry import MissingDatasetTypeError
+from lsst.daf.butler import (
+ Butler,
+ DataCoordinate,
+ DataId,
+ DatasetRef,
+ DatasetType,
+ DimensionGroup,
+ DimensionUniverse,
+ MissingDatasetTypeError,
+)
+from lsst.daf.butler.registry import ConflictingDefinitionError, Registry
from lsst.resources import ResourcePath, ResourcePathExpression
+from lsst.utils.packages import Packages
+from .._dataset_handle import InMemoryDatasetHandle
+from ..automatic_connection_constants import PACKAGES_INIT_OUTPUT_NAME, PACKAGES_INIT_OUTPUT_STORAGE_CLASS
from ._dataset_types import DatasetTypeNode
from ._edges import Edge, ReadEdge, WriteEdge
from ._exceptions import (
@@ -65,6 +78,8 @@
_G = TypeVar("_G", bound=networkx.DiGraph | networkx.MultiDiGraph)
+_LOG = logging.getLogger("lsst.pipe.base.pipeline_graph")
+
class PipelineGraph:
"""A graph representation of fully-configured pipeline.
@@ -1524,6 +1539,343 @@ def split_independent(self) -> Iterable[PipelineGraph]:
component_subgraph.sort()
yield component_subgraph
+ ###########################################################################
+ #
+ # Data repository/collection initialization
+ #
+ ###########################################################################
+
+ @property
+ def packages_dataset_type(self) -> DatasetType:
+ """The special "packages" dataset type that records software versions.
+
+ This is not associated with a task and hence is
+ not considered part of the pipeline graph in other respects, but it
+ does get written with other provenance datasets.
+ """
+ if self.universe is None:
+ raise UnresolvedGraphError(
+ "PipelineGraph must be resolved in order to get the packages dataset type."
+ )
+ return DatasetType(PACKAGES_INIT_OUTPUT_NAME, self.universe.empty, PACKAGES_INIT_OUTPUT_STORAGE_CLASS)
+
+ def register_dataset_types(self, butler: Butler, include_packages: bool = True) -> None:
+ """Register all dataset types in a data repository.
+
+ Parameters
+ ----------
+ butler : `~lsst.daf.butler.Butler`
+ Data repository client.
+ include_packages : `bool`, optional
+ Whether to include the special "packages" dataset type that records
+ software versions (this is not associated with a task and hence is
+ not considered part of the pipeline graph in other respects, but it
+ does get written with other provenance datasets).
+ """
+ dataset_types = [node.dataset_type for node in self.dataset_types.values()]
+ if include_packages:
+ dataset_types.append(self.packages_dataset_type)
+ for dataset_type in dataset_types:
+ butler.registry.registerDatasetType(dataset_type)
+
+ def check_dataset_type_registrations(self, butler: Butler, include_packages: bool = True) -> None:
+ """Check that dataset type registrations in a data repository match
+ the definitions in this pipeline graph.
+
+ Parameters
+ ----------
+ butler : `~lsst.daf.butler.Butler`
+ Data repository client.
+ include_packages : `bool`, optional
+ Whether to include the special "packages" dataset type that records
+ software versions (this is not associated with a task and hence is
+ not considered part of the pipeline graph in other respects, but it
+ does get written with other provenance datasets).
+
+ Raises
+ ------
+ lsst.daf.butler.MissingDatasetTypeError
+ Raised if one or more non-optional-input or output dataset types in
+ the pipeline is not registered at all.
+ lsst.daf.butler.ConflictingDefinitionError
+ Raised if the definition in the data repository is not identical
+ to the definition in the pipeline graph.
+
+ Notes
+ -----
+ Note that dataset type definitions that are storage-class-conversion
+ compatible but not identical are not permitted by these checks, because
+ the expectation is that these differences are handled by `resolve`,
+ which makes the pipeline graph use the data repository definitions.
+ This method is intended to check that none of those definitions have
+ changed.
+ """
+ dataset_types = [node.dataset_type for node in self.dataset_types.values()]
+ if include_packages:
+ dataset_types.append(self.packages_dataset_type)
+ missing_dataset_types: list[str] = []
+ for dataset_type in dataset_types:
+ try:
+ expected = butler.registry.getDatasetType(dataset_type.name)
+ except MissingDatasetTypeError:
+ expected = None
+ if expected is None:
+ # The user probably forgot to register dataset types
+ # at least once (which should be an error),
+ # but we could also get here if this is an optional input for
+ # which no datasets were found in this repo (not an error).
+ if (
+ not (
+ self.producer_of(dataset_type.name) is None
+ and all(
+ self.tasks[input_edge.task_label].is_optional(input_edge.connection_name)
+ for input_edge in self.consuming_edges_of(dataset_type.name)
+ )
+ )
+ or dataset_type.name == PACKAGES_INIT_OUTPUT_NAME
+ ):
+ missing_dataset_types.append(dataset_type.name)
+ elif expected != dataset_type:
+ raise ConflictingDefinitionError(
+ f"DatasetType definition in registry has changed since the pipeline graph was resolved: "
+ f"{dataset_type} (graph) != {expected} (registry)."
+ )
+ if missing_dataset_types:
+ plural = "s" if len(missing_dataset_types) != 1 else ""
+ raise MissingDatasetTypeError(
+ f"Missing dataset type definition{plural}: {', '.join(missing_dataset_types)}. "
+ "Dataset types have to be registered in advance (on the command-line, either via "
+ "`butler register-dataset-type` or the `--register-dataset-types` option to `pipetask run`."
+ )
+
+ def instantiate_tasks(
+ self,
+ get_init_input: Callable[[DatasetType], Any] | None = None,
+ init_outputs: list[tuple[Any, DatasetType]] | None = None,
+ ) -> list[PipelineTask]:
+ """Instantiate all tasks in the pipeline.
+
+ Parameters
+ ----------
+ get_init_input : `~collections.abc.Callable`, optional
+ Callable that accepts a single `~lsst.daf.butler.DatasetType`
+ parameter and returns the init-input dataset associated with that
+ dataset type. Must respect the storage class embedded in the type.
+ This is optional if the pipeline does not have any overall init
+ inputs. When a full butler is available,
+ `lsst.daf.butler.Butler.get` can be used directly here.
+ init_outputs : `list`, optional
+ A list of ``(obj, dataset type)`` init-output dataset pairs, to be
+ appended to in-place. Both the object and the dataset type will
+ correspond to the storage class of the output connection, which
+ may not be the same as the storage class on the graph's dataset
+ type node.
+
+ Returns
+ -------
+ tasks : `list`
+ Constructed `PipelineTask` instances.
+
+ Raises
+ ------
+ UnresolvedGraphError
+ Raised if the pipeline graph is not fully resolved.
+ ValueError
+ Raised if an init-input has to be fetched (no in-memory handle for
+ it exists yet) and ``get_init_input`` is `None`.
+ """
+ if not self.is_fully_resolved:
+ raise UnresolvedGraphError("Pipeline graph must be fully resolved before instantiating tasks.")
+ empty_data_id = DataCoordinate.make_empty(cast(DimensionUniverse, self.universe))
+ # In-memory init datasets keyed by dataset type name; filled from task
+ # init-outputs and from cached overall init-inputs, and consulted
+ # before falling back to ``get_init_input``.
+ handles: dict[str, InMemoryDatasetHandle] = {}
+ tasks: list[PipelineTask] = []
+ for task_node in self.tasks.values():
+ task_init_inputs: dict[str, Any] = {}
+ for read_edge in task_node.init.inputs.values():
+ # Case 1: the exact dataset type is already held in memory.
+ if (handle := handles.get(read_edge.dataset_type_name)) is not None:
+ obj = handle.get(storageClass=read_edge.storage_class_name)
+ # Case 2: a component is requested and its parent is in memory.
+ elif (
+ read_edge.component is not None
+ and (parent_handle := handles.get(read_edge.parent_dataset_type_name)) is not None
+ ):
+ obj = parent_handle.get(
+ storageClass=read_edge.storage_class_name, component=read_edge.component
+ )
+ # Case 3: nothing in memory; ask the caller-provided callback.
+ else:
+ dataset_type_node = self.dataset_types[read_edge.parent_dataset_type_name]
+ if get_init_input is None:
+ raise ValueError(
+ f"Task {task_node.label!r} requires init-input "
+ f"{read_edge.dataset_type_name} but no 'get_init_input' callback was provided."
+ )
+ obj = get_init_input(read_edge.adapt_dataset_type(dataset_type_node.dataset_type))
+ n_consumers = len(self.consumers_of(dataset_type_node.name))
+ if (
+ n_consumers > 1
+ and read_edge.component is None
+ and read_edge.storage_class_name == dataset_type_node.storage_class_name
+ ):
+ # Caching what we just got is safe in general only
+ # if there was no storage class conversion, since
+ # a->b and a->c does not imply b->c.
+ handles[read_edge.dataset_type_name] = InMemoryDatasetHandle(
+ obj,
+ storageClass=dataset_type_node.storage_class,
+ dataId=empty_data_id,
+ copy=True,
+ )
+ task_init_inputs[read_edge.connection_name] = obj
+ task = task_node.task_class(
+ config=task_node.config, initInputs=task_init_inputs, name=task_node.label
+ )
+ tasks.append(task)
+ for write_edge in task_node.init.outputs.values():
+ dataset_type_node = self.dataset_types[write_edge.parent_dataset_type_name]
+ # The init-output object is read from the task attribute named
+ # after the connection.
+ obj = getattr(task, write_edge.connection_name)
+ # We don't immediately coerce obj to the dataset_type_node
+ # storage class (which should be the repo storage class, if
+ # there is one) when appending to `init_outputs` because a
+ # formatter might be able to do a better job of that later;
+ # instead we pair it with a dataset type that's consistent with
+ # the in-memory type. We do coerce when populating `handles`,
+ # though, because going through the dataset_type_node storage
+ # class is the conversion path we checked when we resolved the
+ # pipeline graph.
+ if init_outputs is not None:
+ init_outputs.append((obj, write_edge.adapt_dataset_type(dataset_type_node.dataset_type)))
+ n_consumers = len(self.consumers_of(dataset_type_node.name))
+ # Only keep a handle when some task in the graph consumes this
+ # dataset; request copies only when there is more than one consumer.
+ if n_consumers > 0:
+ handles[dataset_type_node.name] = InMemoryDatasetHandle(
+ dataset_type_node.storage_class.coerce_type(obj),
+ dataId=empty_data_id,
+ storageClass=dataset_type_node.storage_class,
+ copy=(n_consumers > 1),
+ )
+ return tasks
+
+ def write_init_outputs(self, butler: Butler) -> None:
+ """Write the init-output datasets for all tasks in the pipeline graph.
+
+ Parameters
+ ----------
+ butler : `lsst.daf.butler.Butler`
+ A full butler data repository client with its default run set
+ to the collection where datasets should be written.
+
+ Raises
+ ------
+ FileNotFoundError
+ Raised if an init-output dataset is found in the output run but is
+ not actually stored.
+
+ Notes
+ -----
+ Datasets that already exist in the butler's output run collection will
+ not be written.
+
+ This method writes outputs with new random dataset IDs and should
+ hence only be used when writing init-outputs prior to building a
+ `QuantumGraph`. Use `QuantumGraph.write_init_outputs` if a quantum
+ graph has already been built.
+ """
+ init_outputs: list[tuple[Any, DatasetType]] = []
+ self.instantiate_tasks(butler.get, init_outputs)
+ # Split init-outputs into those already present in the output run
+ # (keyed by dataset type name) and those that still need to be written.
+ found_refs: dict[str, DatasetRef] = {}
+ to_put: list[tuple[Any, DatasetType]] = []
+ for obj, dataset_type in init_outputs:
+ if (ref := butler.find_dataset(dataset_type, collections=butler.run)) is not None:
+ found_refs[dataset_type.name] = ref
+ else:
+ to_put.append((obj, dataset_type))
+ # Every dataset found in the run must also be stored; this is verified
+ # before anything new is written.
+ for ref, stored in butler.stored_many(found_refs.values()).items():
+ if not stored:
+ raise FileNotFoundError(
+ f"Init-output dataset {ref.datasetType.name!r} was found in RUN {ref.run!r} "
+ f"but had not actually been stored (or was stored and later deleted)."
+ )
+ for obj, dataset_type in to_put:
+ butler.put(obj, dataset_type)
+
+ def write_configs(self, butler: Butler) -> None:
+ """Write the config datasets for all tasks in the pipeline graph.
+
+ Parameters
+ ----------
+ butler : `lsst.daf.butler.Butler`
+ A full butler data repository client with its default run set
+ to the collection where datasets should be written.
+
+ Notes
+ -----
+ Config datasets that already exist in the butler's output run
+ collection will be checked for consistency.
+
+ This method writes outputs with new random dataset IDs and should
+ hence only be used when writing init-outputs prior to building a
+ `QuantumGraph`. Use `QuantumGraph.write_configs` if a quantum graph
+ has already been built.
+
+ Raises
+ ------
+ lsst.daf.butler.registry.ConflictingDefinitionError
+ Raised if a config dataset already exists and is not consistent
+ with the config in the pipeline graph.
+ """
+ # Pairs of (config, dataset type name) that still need to be written.
+ to_put: list[tuple[PipelineTaskConfig, str]] = []
+ for task_node in self.tasks.values():
+ dataset_type_name = task_node.init.config_output.dataset_type_name
+ # Only the output run is searched for a previously written config.
+ if (ref := butler.find_dataset(dataset_type_name, collections=butler.run)) is not None:
+ old_config = butler.get(ref)
+ # Mismatches are reported through log_config_mismatch.
+ if not task_node.config.compare(old_config, shortcut=False, output=log_config_mismatch):
+ raise ConflictingDefinitionError(
+ f"Config does not match existing task config {dataset_type_name!r} in "
+ "butler; tasks configurations must be consistent within the same run collection"
+ )
+ else:
+ to_put.append((task_node.config, dataset_type_name))
+ # We do writes at the end to minimize the mess we leave behind when we
+ # raise an exception.
+ for config, dataset_type_name in to_put:
+ butler.put(config, dataset_type_name)
+
+ def write_packages(self, butler: Butler) -> None:
+ """Write the 'packages' dataset for the currently-active software
+ versions.
+
+ Parameters
+ ----------
+ butler : `lsst.daf.butler.Butler`
+ A full butler data repository client with its default run set
+ to the collection where datasets should be written.
+
+ Notes
+ -----
+ If the packages dataset already exists, it will be compared to the
+ versions in the current packages. New packages that weren't present
+ before are not considered an inconsistency.
+
+ This method writes outputs with new random dataset IDs and should
+ hence only be used when writing init-outputs prior to building a
+ `QuantumGraph`. Use `QuantumGraph.write_packages` if a quantum graph
+ has already been built.
+
+ Raises
+ ------
+ lsst.daf.butler.registry.ConflictingDefinitionError
+ Raised if the packages dataset already exists and is not consistent
+ with the current packages.
+ """
+ new_packages = Packages.fromSystem()
+ if (ref := butler.find_dataset(self.packages_dataset_type)) is not None:
+ packages = butler.get(ref)
+ if compare_packages(packages, new_packages):
+ # have to remove existing dataset first; butler has no
+ # replace option.
+ butler.pruneDatasets([ref], unstore=True, purge=True)
+ butler.put(packages, ref)
+ else:
+ butler.put(new_packages, self.packages_dataset_type)
+
+ def init_output_run(self, butler: Butler) -> None:
+ """Initialize a new output RUN collection by writing init-output
+ datasets (including configs and packages).
+
+ Parameters
+ ----------
+ butler : `lsst.daf.butler.Butler`
+ A full butler data repository client with its default run set
+ to the collection where datasets should be written.
+ """
+ # Configs and packages are checked/written first, so an inconsistency
+ # with existing datasets is reported before any task is instantiated
+ # by write_init_outputs.
+ self.write_configs(butler)
+ self.write_packages(butler)
+ self.write_init_outputs(butler)
+
###########################################################################
#
# Class- and Package-Private Methods.
@@ -1818,3 +2170,51 @@ def _reset(self) -> None:
_dataset_types: DatasetTypeMappingView
_raw_data_id: dict[str, Any]
_universe: DimensionUniverse | None
+
+
+def log_config_mismatch(msg: str) -> None:
+ """Log messages about configuration mismatch.
+
+ Intended to be passed as the ``output`` callback of a config comparison.
+
+ Parameters
+ ----------
+ msg : `str`
+ Log message to use.
+ """
+ # Emitted at the highest severity on the module logger.
+ _LOG.fatal("Comparing configuration: %s", msg)
+
+
+def compare_packages(packages: Packages, new_packages: Packages) -> bool:
+ """Compare two versions of Packages.
+
+ Parameters
+ ----------
+ packages : `Packages`
+ Previously recorded package versions. Updated in place to include
+ any new packages that weren't present before.
+ new_packages : `Packages`
+ New set of package versions.
+
+ Returns
+ -------
+ updated : `bool`
+ `True` if ``packages`` was updated, `False` if not.
+
+ Raises
+ ------
+ ConflictingDefinitionError
+ Raised if versions are inconsistent.
+ """
+ diff = new_packages.difference(packages)
+ if diff:
+ versions_str = "; ".join(f"{pkg}: {diff[pkg][1]} vs {diff[pkg][0]}" for pkg in diff)
+ raise ConflictingDefinitionError(f"Package versions mismatch: ({versions_str})")
+ else:
+ _LOG.debug("new packages are consistent with old")
+ # Update the old set of packages in case we have more packages
+ # that haven't been persisted.
+ extra = new_packages.extra(packages)
+ if extra:
+ _LOG.debug("extra packages: %s", extra)
+ packages.update(new_packages)
+ return True
+ return False
diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py
index 576869bf4..5c462c5e2 100644
--- a/python/lsst/pipe/base/pipeline_graph/_tasks.py
+++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py
@@ -48,7 +48,7 @@
from .. import automatic_connection_constants as acc
from ..connections import PipelineTaskConnections
-from ..connectionTypes import BaseConnection, InitOutput, Output
+from ..connectionTypes import BaseConnection, BaseInput, InitOutput, Output
from ._edges import Edge, ReadEdge, WriteEdge
from ._exceptions import TaskNotImportedError, UnresolvedGraphError
from ._nodes import NodeKey, NodeType
@@ -216,7 +216,7 @@ class TaskInitNode:
- ``task_class_name``
- ``bipartite`` (see `NodeType.bipartite`)
- ``task_class`` (only if `is_imported` is `True`)
- - ``config`` (only if `is_importd` is `True`)
+ - ``config`` (only if `is_imported` is `True`)
"""
def __init__(
@@ -798,6 +798,23 @@ def get_lookup_function(
"""
return getattr(self._get_imported_data().connection_map[connection_name], "lookupFunction", None)
+    def is_optional(self, connection_name: str) -> bool:
+        """Report whether a connection is an input with ``minimum == 0``.
+
+        Parameters
+        ----------
+        connection_name : `str`
+            Name of the connection.
+
+        Returns
+        -------
+        optional : `bool`
+            `True` only if the connection is a `BaseInput` whose ``minimum``
+            is zero, i.e. this task can run without any datasets for it.
+        """
+        connections = self.get_connections()
+        connection = getattr(connections, connection_name)
+        # Only input connections carry a ``minimum``; anything else is
+        # never considered optional.
+        if not isinstance(connection, BaseInput):
+            return False
+        return connection.minimum == 0
+
def get_connections(self) -> PipelineTaskConnections:
"""Return the connections class instance for this task.
diff --git a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py
index 0c2d7ca23..70931e4af 100644
--- a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py
+++ b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py
@@ -563,7 +563,7 @@ class DynamicConnectionConfig(Config):
default=True,
)
minimum = Field[int](
- doc="Minimum number of datasets per quantum requried for this connection. Ignored for non-inputs.",
+ doc="Minimum number of datasets per quantum required for this connection. Ignored for non-inputs.",
dtype=int,
default=1,
)
diff --git a/tests/test_init_output_run.py b/tests/test_init_output_run.py
new file mode 100644
index 000000000..2cc6341e1
--- /dev/null
+++ b/tests/test_init_output_run.py
@@ -0,0 +1,535 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import itertools
+import tempfile
+import unittest
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import ClassVar
+
+import lsst.utils.tests
+from lsst.daf.butler import (
+ Butler,
+ DatasetRef,
+ DatasetType,
+ MissingDatasetTypeError,
+ QuantumBackedButler,
+ SerializedDatasetType,
+ StorageClassFactory,
+)
+from lsst.daf.butler.registry import ConflictingDefinitionError
+from lsst.pipe.base import QuantumGraph
+from lsst.pipe.base.all_dimensions_quantum_graph_builder import AllDimensionsQuantumGraphBuilder
+from lsst.pipe.base.pipeline_graph import PipelineGraph
+from lsst.pipe.base.tests.mocks import (
+ DynamicConnectionConfig,
+ DynamicTestPipelineTask,
+ DynamicTestPipelineTaskConfig,
+ MockDataset,
+)
+
+
+def _have_example_storage_classes() -> bool:
+    """Report whether the example storage classes convert as expected.
+
+    These storage classes have registered converters, so importing their
+    types should not be needed to know that they are convertible.  The
+    storage class machinery nevertheless treats types that cannot be imported
+    as not convertible.  That is inconvenient for this test, but fine outside
+    of testing, where a storage class only matters if it can actually be
+    used.
+    """
+    get_storage_class = StorageClassFactory().getStorageClass
+    # Each entry is (storage class asked to convert, storage class given).
+    convertible_pairs = (
+        ("ArrowTable", "ArrowAstropy"),
+        ("ArrowAstropy", "ArrowTable"),
+        ("ArrowTable", "DataFrame"),
+        ("DataFrame", "ArrowTable"),
+    )
+    return all(
+        get_storage_class(first).can_convert(get_storage_class(second))
+        for first, second in convertible_pairs
+    )
+
+
+class InitOutputRunTestCase(unittest.TestCase):
+ """Tests for the init_output_run methods of PipelineGraph and
+ QuantumGraph.
+ """
+
+ INPUT_COLLECTION: ClassVar[str] = "overall_inputs"
+
+    @contextmanager
+    def make_butler(self) -> Iterator[Butler]:
+        """Yield a writeable butler for a repository created in a temporary
+        directory.
+        """
+        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as repo_root:
+            Butler.makeRepo(repo_root)
+            temp_butler = Butler.from_config(repo_root, writeable=True)
+            yield temp_butler
+
+    @contextmanager
+    def prep_butler(self, pipeline_graph: PipelineGraph) -> Iterator[Butler]:
+        """Yield a temporary butler that holds the dataset types and overall
+        input datasets of a pipeline graph.
+
+        The pipeline graph is resolved here, and its dataset types are checked
+        both before and right after they are registered, which gives test
+        coverage for the methods that do that.
+        """
+        with self.make_butler() as butler:
+            butler.collections.register(self.INPUT_COLLECTION)
+            pipeline_graph.resolve(butler.registry)
+            # Before registration the check is expected to fail.
+            with self.assertRaises(MissingDatasetTypeError):
+                pipeline_graph.check_dataset_type_registrations(butler)
+            pipeline_graph.register_dataset_types(butler)
+            pipeline_graph.check_dataset_type_registrations(butler)
+            # Store one mock dataset for every overall input.
+            for _, node in pipeline_graph.iter_overall_inputs():
+                serialized_type = SerializedDatasetType(
+                    name=node.name,
+                    dimensions=[],
+                    storageClass=node.storage_class_name,
+                )
+                mock_dataset = MockDataset(
+                    dataset_id=None,
+                    dataset_type=serialized_type,
+                    data_id={},
+                    run=self.INPUT_COLLECTION,
+                )
+                butler.put(mock_dataset, node.name, run=self.INPUT_COLLECTION)
+            yield butler
+
+    def find_init_output_refs(
+        self, pipeline_graph: PipelineGraph, butler: Butler
+    ) -> dict[str, list[DatasetRef]]:
+        """Look up the init-output datasets of a pipeline graph in a butler
+        repository.
+
+        Parameters
+        ----------
+        pipeline_graph : `PipelineGraph`
+            Pipeline graph.
+        butler : `Butler`
+            Full butler client.
+
+        Returns
+        -------
+        init_output_refs : `dict`
+            Dataset references, keyed by task label.  Each one has been
+            passed through the write edge's ``adapt_dataset_ref``.  The
+            special 'packages' dataset type is included under a '*' key.
+        """
+        found: dict[str, list[DatasetRef]] = {}
+        for task_node in pipeline_graph.tasks.values():
+            task_refs: list[DatasetRef] = []
+            for write_edge in task_node.init.iter_all_outputs():
+                name = write_edge.dataset_type_name
+                ref = butler.find_dataset(name)
+                # The ref that comes back must use the dataset type node's
+                # definition of the dataset type (including storage class).
+                self.assertEqual(ref.datasetType, pipeline_graph.dataset_types[name].dataset_type)
+                # Keep the version of the ref adapted to the task's storage
+                # class, in case the two differ.
+                task_refs.append(write_edge.adapt_dataset_ref(ref))
+            found[task_node.label] = task_refs
+        found["*"] = [butler.find_dataset(pipeline_graph.packages_dataset_type)]
+        return found
+
+    def get_quantum_graph_init_output_refs(self, quantum_graph: QuantumGraph) -> dict[str, list[DatasetRef]]:
+        """Collect the init-output references held by a QuantumGraph, in the
+        same form that `find_init_output_refs` returns.
+        """
+        refs_by_label: dict[str, list[DatasetRef]] = {
+            label: quantum_graph.get_init_output_refs(label) for label in quantum_graph.pipeline_graph.tasks
+        }
+        # Global init-outputs go under the same '*' key.
+        refs_by_label["*"] = list(quantum_graph.globalInitOutputRefs())
+        return refs_by_label
+
+    def assert_init_output_refs_equal(
+        self, a: dict[str, list[DatasetRef]], b: dict[str, list[DatasetRef]]
+    ) -> None:
+        """Assert that two dictionaries of the form returned by
+        `find_init_output_refs` have the same keys and, for each key, the
+        same references regardless of order.
+        """
+        self.assertEqual(a.keys(), b.keys())
+        for label in a:
+            self.assertCountEqual(a[label], b[label])
+
+    def check_qbb_consistency(
+        self, init_output_refs: dict[str, list[DatasetRef]], qbb: QuantumBackedButler
+    ) -> None:
+        """Check that a quantum-backed butler sees all of the given datasets.
+
+        Parameters
+        ----------
+        init_output_refs : `dict`
+            Dataset references, keyed by task label.  Storage classes should
+            match the data repository definitions of the dataset types.  The
+            special 'packages' dataset type should be included under a '*' key.
+        qbb : `lsst.daf.butler.QuantumBackedButler`
+            A quantum-backed butler.
+        """
+        for task_label, init_output_refs_for_task in init_output_refs.items():
+            for ref, stored_in in qbb.stored_many(init_output_refs_for_task).items():
+                # These are init-output refs, so the failure message must say
+                # so; it previously called them init-inputs.
+                self.assertTrue(
+                    stored_in, msg=f"Init-output {ref} of task {task_label} not stored according to QBB."
+                )
+
+    def init_with_pipeline_graph_first(
+        self, pipeline_graph: PipelineGraph, butler: Butler, run: str
+    ) -> QuantumGraph:
+        """Exercise the init_output_run methods of PipelineGraph and
+        QuantumGraph, with the PipelineGraph writing the init-outputs.
+
+        Later attempts to initialize the same run are expected to do nothing
+        or to fail, depending on parameters.
+
+        Parameters
+        ----------
+        pipeline_graph : `PipelineGraph`
+            Pipeline graph.
+        butler : `Butler`
+            Full butler client; a clone bound to ``run`` is used internally.
+        run : `str`
+            Name of the output run.
+
+        Returns
+        -------
+        quantum_graph : `QuantumGraph`
+            The quantum graph built during the test.
+        """
+        run_butler = butler._clone(run=run, collections=[self.INPUT_COLLECTION, run])
+        pipeline_graph.init_output_run(run_butler)
+        expected_refs = self.find_init_output_refs(pipeline_graph, run_butler)
+        # Build a QG with the init outputs already in place.
+        builder = AllDimensionsQuantumGraphBuilder(
+            pipeline_graph,
+            run_butler,
+            skip_existing_in=[run],
+            output_run=run,
+            input_collections=[self.INPUT_COLLECTION],
+        )
+        qg = builder.build(metadata={"output_run": run}, attach_datastore_records=True)
+        # The QG refs must be the ones that were present already.
+        self.assert_init_output_refs_equal(self.get_quantum_graph_init_output_refs(qg), expected_refs)
+        # Initialize with the pipeline graph; should be a no-op.
+        pipeline_graph.init_output_run(run_butler)
+        found_refs = self.find_init_output_refs(pipeline_graph, run_butler)
+        self.assert_init_output_refs_equal(found_refs, expected_refs)
+        # Initialize again with the QG; should be a no-op.
+        qg.init_output_run(run_butler)
+        found_refs = self.find_init_output_refs(pipeline_graph, run_butler)
+        self.assert_init_output_refs_equal(found_refs, expected_refs)
+        # Initialize again with the QG but tell it to expect an empty run.
+        with self.assertRaises(ConflictingDefinitionError):
+            qg.init_output_run(run_butler, existing=False)
+        with self.assertRaises(ConflictingDefinitionError):
+            qg.write_configs(run_butler, compare_existing=False)
+        with self.assertRaises(ConflictingDefinitionError):
+            qg.write_packages(run_butler, compare_existing=False)
+        with self.assertRaises(ConflictingDefinitionError):
+            qg.write_init_outputs(run_butler, skip_existing=False)
+        # Make a QBB, check that it can see the init outputs.
+        qbb = qg.make_init_qbb(run_butler._config)
+        self.check_qbb_consistency(expected_refs, qbb)
+        # Use QBB to initialize again; should be a no-op.
+        qg.init_output_run(qbb)
+        self.check_qbb_consistency(expected_refs, qbb)
+        # Use QBB to initialize again but tell it to expect an empty run.
+        with self.assertRaises(ConflictingDefinitionError):
+            qg.init_output_run(qbb, existing=False)
+        return qg
+
+    def init_with_quantum_graph_first(
+        self, pipeline_graph: PipelineGraph, butler: Butler, run: str
+    ) -> QuantumGraph:
+        """Exercise the init_output_run methods of PipelineGraph and
+        QuantumGraph, with the QuantumGraph writing the init-outputs.
+
+        Later attempts to initialize the same run are expected to do nothing
+        or to fail, depending on parameters.
+
+        Parameters
+        ----------
+        pipeline_graph : `PipelineGraph`
+            Pipeline graph.
+        butler : `Butler`
+            Full butler client; a clone bound to ``run`` is used internally.
+        run : `str`
+            Name of the output run.
+
+        Returns
+        -------
+        quantum_graph : `QuantumGraph`
+            The quantum graph built during the test.
+        """
+        run_butler = butler._clone(run=run, collections=[self.INPUT_COLLECTION, run])
+        # Build a QG.
+        builder = AllDimensionsQuantumGraphBuilder(
+            pipeline_graph,
+            run_butler,
+            input_collections=[self.INPUT_COLLECTION],
+        )
+        qg = builder.build(metadata={"output_run": run}, attach_datastore_records=True)
+        # Initialize with the QG.
+        qg.init_output_run(run_butler)
+        # The QG refs must be the same as the ones found in the repo.
+        expected_refs = self.find_init_output_refs(pipeline_graph, run_butler)
+        self.assert_init_output_refs_equal(self.get_quantum_graph_init_output_refs(qg), expected_refs)
+        # Initialize again with the QG; should be a no-op.
+        qg.init_output_run(run_butler)
+        found_refs = self.find_init_output_refs(pipeline_graph, run_butler)
+        self.assert_init_output_refs_equal(found_refs, expected_refs)
+        # Initialize again with the QG but tell it to expect an empty run.
+        with self.assertRaises(ConflictingDefinitionError):
+            qg.init_output_run(run_butler, existing=False)
+        with self.assertRaises(ConflictingDefinitionError):
+            qg.write_configs(run_butler, compare_existing=False)
+        with self.assertRaises(ConflictingDefinitionError):
+            qg.write_packages(run_butler, compare_existing=False)
+        with self.assertRaises(ConflictingDefinitionError):
+            qg.write_init_outputs(run_butler, skip_existing=False)
+        # Initialize with the pipeline graph; should be a no-op.
+        pipeline_graph.init_output_run(run_butler)
+        # Make a QBB, check that it can see the init outputs.
+        qbb = qg.make_init_qbb(run_butler._config)
+        self.check_qbb_consistency(expected_refs, qbb)
+        # Use QBB to initialize again; should be a no-op.
+        qg.init_output_run(qbb)
+        self.check_qbb_consistency(expected_refs, qbb)
+        # Use QBB to initialize again but tell it to expect an empty run.
+        with self.assertRaises(ConflictingDefinitionError):
+            qg.init_output_run(qbb, existing=False)
+        return qg
+
+    def init_with_qbb_first(self, pipeline_graph: PipelineGraph, butler: Butler, run: str) -> QuantumGraph:
+        """Test the init_output_run methods of PipelineGraph and QuantumGraph,
+        using the latter with a quantum-backed butler to actually write
+        init-outputs (with later attempts correctly failing or doing nothing,
+        depending on parameters).
+
+        Parameters
+        ----------
+        pipeline_graph : `PipelineGraph`
+            Pipeline graph.
+        butler : `Butler`
+            Full butler client; a clone bound to ``run`` is used internally.
+        run : `str`
+            Name of the output run.
+
+        Returns
+        -------
+        quantum_graph : `QuantumGraph`
+            The quantum graph built during the test.
+        """
+        butler = butler._clone(run=run, collections=[self.INPUT_COLLECTION, run])
+        # Build a QG.
+        quantum_graph_builder = AllDimensionsQuantumGraphBuilder(
+            pipeline_graph,
+            butler,
+            input_collections=[self.INPUT_COLLECTION],
+        )
+        quantum_graph = quantum_graph_builder.build(
+            metadata={"output_run": run}, attach_datastore_records=True
+        )
+        # Make a quantum-backed butler and use it to initialize the run.
+        qbb = quantum_graph.make_init_qbb(butler._config)
+        quantum_graph.init_output_run(qbb)
+        init_output_refs = self.get_quantum_graph_init_output_refs(quantum_graph)
+        self.check_qbb_consistency(init_output_refs, qbb)
+        # Use QBB to initialize again, should be a no-op.  The call was
+        # missing here, so the no-op path was never actually exercised.
+        quantum_graph.init_output_run(qbb)
+        self.check_qbb_consistency(init_output_refs, qbb)
+        # Use QBB to initialize again but tell it to expect an empty run.
+        with self.assertRaises(ConflictingDefinitionError):
+            quantum_graph.init_output_run(qbb, existing=False)
+        # Transfer datasets back to the main butler (i.e. insert DB entries
+        # for them).
+        butler.transfer_from(qbb, itertools.chain.from_iterable(init_output_refs.values()))
+        # Check that the QG refs are the same as the ones we find in the repo.
+        self.assert_init_output_refs_equal(
+            self.find_init_output_refs(pipeline_graph, butler),
+            init_output_refs,
+        )
+        # Initialize again with the QG; should be a no-op.
+        quantum_graph.init_output_run(butler)
+        self.assert_init_output_refs_equal(
+            self.find_init_output_refs(pipeline_graph, butler),
+            init_output_refs,
+        )
+        # Initialize again with the QG but tell it to expect an empty run.
+        with self.assertRaises(ConflictingDefinitionError):
+            quantum_graph.init_output_run(butler, existing=False)
+        # Initialize with the pipeline graph, should be a no-op.
+        pipeline_graph.init_output_run(butler)
+        return quantum_graph
+
+    def test_two_tasks_no_conversions(self) -> None:
+        """Run every initialization order on a two-task pipeline that has an
+        overall init-input, an overall init-output, and an init-intermediate.
+        """
+        a_config = DynamicTestPipelineTaskConfig()
+        a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime")
+        a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime")
+        a_config.init_inputs["ii"] = DynamicConnectionConfig(dataset_type_name="input_init")
+        a_config.init_outputs["io"] = DynamicConnectionConfig(dataset_type_name="intermediate_init")
+        b_config = DynamicTestPipelineTaskConfig()
+        b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime")
+        b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime")
+        b_config.init_inputs["ii"] = DynamicConnectionConfig(dataset_type_name="intermediate_init")
+        b_config.init_outputs["io"] = DynamicConnectionConfig(dataset_type_name="output_init")
+        pipeline_graph = PipelineGraph()
+        pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config)
+        pipeline_graph.add_task("b", DynamicTestPipelineTask, b_config)
+        # Each helper writes the init-outputs a different way, into its own
+        # run.
+        initializers = (
+            ("run1", self.init_with_pipeline_graph_first),
+            ("run2", self.init_with_quantum_graph_first),
+            ("run3", self.init_with_qbb_first),
+        )
+        with self.prep_butler(pipeline_graph) as butler:
+            for run, initialize in initializers:
+                initialize(pipeline_graph, butler, run)
+                self.assertEqual(butler.get("a_config", collections=run), a_config)
+                self.assertEqual(butler.get("b_config", collections=run), b_config)
+
+    def test_optional_input_unregistered(self) -> None:
+        """Check that leaving an optional input dataset type unregistered is
+        not treated as an error.
+        """
+        a_config = DynamicTestPipelineTaskConfig()
+        a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime", minimum=0)
+        a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime")
+        pipeline_graph = PipelineGraph()
+        pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config)
+        with self.make_butler() as butler:
+            pipeline_graph.resolve(butler.registry)
+            # Register everything except the optional input.
+            for name in ("a_config", "a_log", "a_metadata", "output_runtime"):
+                butler.registry.registerDatasetType(pipeline_graph.dataset_types[name].dataset_type)
+            pipeline_graph.check_dataset_type_registrations(butler, include_packages=False)
+
+    def test_registration_changed(self) -> None:
+        """Check that an error is raised when a dataset type registration in
+        the data repository changes between the time a pipeline graph is
+        resolved (e.g. at QG generation) and the time dataset types are
+        checked later (e.g. during execution).
+        """
+        a_config = DynamicTestPipelineTaskConfig()
+        a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime")
+        a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime")
+        pipeline_graph = PipelineGraph()
+        pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config)
+        with self.make_butler() as butler:
+            pipeline_graph.resolve(butler.registry)
+            pipeline_graph.register_dataset_types(butler)
+            # Replace the input's registration with a different definition.
+            butler.registry.removeDatasetType("input_runtime")
+            changed_type = DatasetType(
+                "input_runtime", {"instrument"}, "StructuredDataList", universe=butler.dimensions
+            )
+            butler.registry.registerDatasetType(changed_type)
+            with self.assertRaises(ConflictingDefinitionError):
+                pipeline_graph.check_dataset_type_registrations(butler)
+
+    @unittest.skipUnless(
+        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
+    )
+    def test_init_intermediate_component(self) -> None:
+        """Run every initialization order with an init-intermediate that is
+        written as a composite and read as a component.
+        """
+        a_config = DynamicTestPipelineTaskConfig()
+        a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime")
+        a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime")
+        a_config.init_outputs["io"] = DynamicConnectionConfig(
+            dataset_type_name="intermediate_init", storage_class="ArrowTable"
+        )
+        b_config = DynamicTestPipelineTaskConfig()
+        b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime")
+        b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime")
+        b_config.init_inputs["ii"] = DynamicConnectionConfig(
+            dataset_type_name="intermediate_init.schema", storage_class="ArrowSchema"
+        )
+        pipeline_graph = PipelineGraph()
+        pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config)
+        pipeline_graph.add_task("b", DynamicTestPipelineTask, b_config)
+        # Each helper writes the init-outputs a different way, into its own
+        # run.
+        initializers = (
+            ("run1", self.init_with_pipeline_graph_first),
+            ("run2", self.init_with_quantum_graph_first),
+            ("run3", self.init_with_qbb_first),
+        )
+        with self.prep_butler(pipeline_graph) as butler:
+            for run, initialize in initializers:
+                initialize(pipeline_graph, butler, run)
+                self.assertEqual(butler.get("a_config", collections=run), a_config)
+                self.assertEqual(butler.get("b_config", collections=run), b_config)
+
+    def test_no_get_init_input_callback(self) -> None:
+        """Check that PipelineGraph.instantiate_tasks raises when a
+        get_init_input callback is needed but none is given.
+        """
+        a_config = DynamicTestPipelineTaskConfig()
+        a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime")
+        a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime")
+        # The init-input is what makes the callback necessary.
+        a_config.init_inputs["ii"] = DynamicConnectionConfig(dataset_type_name="input_init")
+        pipeline_graph = PipelineGraph()
+        pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config)
+        with self.make_butler() as butler:
+            pipeline_graph.resolve(butler.registry)
+            with self.assertRaises(ValueError):
+                pipeline_graph.instantiate_tasks()
+
+    def test_multiple_init_input_consumers(self) -> None:
+        """Run every initialization order when two tasks consume the same
+        init-input.
+        """
+        a_config = DynamicTestPipelineTaskConfig()
+        a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime")
+        a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime")
+        a_config.init_inputs["ii"] = DynamicConnectionConfig(dataset_type_name="input_init")
+        a_config.init_outputs["io"] = DynamicConnectionConfig(dataset_type_name="output_init")
+        b_config = DynamicTestPipelineTaskConfig()
+        b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime")
+        b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime")
+        b_config.init_inputs["ii"] = DynamicConnectionConfig(dataset_type_name="input_init")
+        pipeline_graph = PipelineGraph()
+        pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config)
+        pipeline_graph.add_task("b", DynamicTestPipelineTask, b_config)
+        # Each helper writes the init-outputs a different way, into its own
+        # run.
+        initializers = (
+            ("run1", self.init_with_pipeline_graph_first),
+            ("run2", self.init_with_quantum_graph_first),
+            ("run3", self.init_with_qbb_first),
+        )
+        with self.prep_butler(pipeline_graph) as butler:
+            for run, initialize in initializers:
+                initialize(pipeline_graph, butler, run)
+                self.assertEqual(butler.get("a_config", collections=run), a_config)
+                self.assertEqual(butler.get("b_config", collections=run), b_config)
+
+    def test_config_change(self) -> None:
+        """Check that init_output_run fails when the run already holds a
+        config that is inconsistent with the one in the pipeline graph.
+        """
+        a_config = DynamicTestPipelineTaskConfig()
+        a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime")
+        a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime")
+        pipeline_graph = PipelineGraph()
+        pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config)
+        with self.prep_butler(pipeline_graph) as butler:
+            butler.collections.register("run1")
+            # Store a default-constructed config, which differs from a_config.
+            butler.put(DynamicTestPipelineTaskConfig(), "a_config", run="run1")
+            run_collections = [self.INPUT_COLLECTION, "run1"]
+            with self.assertRaises(ConflictingDefinitionError):
+                pipeline_graph.init_output_run(butler._clone(run="run1", collections=run_collections))
+            builder = AllDimensionsQuantumGraphBuilder(
+                pipeline_graph,
+                butler,
+                skip_existing_in=["run1"],
+                output_run="run1",
+                input_collections=[self.INPUT_COLLECTION],
+            )
+            qg = builder.build(metadata={"output_run": "run1"}, attach_datastore_records=True)
+            # The quantum graph must report the same conflict.
+            with self.assertRaises(ConflictingDefinitionError):
+                qg.init_output_run(butler._clone(run="run1", collections=run_collections))
+
+
+# Run the tests in this file when it is executed as a script.
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()
diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py
index d4935ccd9..1095785db 100644
--- a/tests/test_pipeline_graph.py
+++ b/tests/test_pipeline_graph.py
@@ -121,6 +121,10 @@ def test_unresolved_accessors(self) -> None:
self.assertEqual(
repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)"
)
+ with self.assertRaises(UnresolvedGraphError):
+ self.graph.packages_dataset_type
+ with self.assertRaises(UnresolvedGraphError):
+ self.graph.instantiate_tasks()
def test_sorting(self) -> None:
"""Test sort methods on PipelineGraph."""
@@ -198,6 +202,7 @@ def test_resolved_accessors(self) -> None:
self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty)
self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict")
self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict")
+ self.assertEqual(self.graph.packages_dataset_type.name, acc.PACKAGES_INIT_OUTPUT_NAME)
def test_resolved_xgraph_export(self) -> None:
"""Test exporting a resolved PipelineGraph to networkx in various