diff --git a/doc/changes/DM-38041.feature.md b/doc/changes/DM-38041.feature.md new file mode 100644 index 000000000..e2ef1ae92 --- /dev/null +++ b/doc/changes/DM-38041.feature.md @@ -0,0 +1,3 @@ +Add support for initializing processing output runs with just a pipeline graph, not a quantum graph. + +This also moves much of the logic for initializing output runs from `lsst.ctrl.mpexec.PreExecInit` to `PipelineGraph` and `QuantumGraph` methods. diff --git a/python/lsst/pipe/base/graph/_versionDeserializers.py b/python/lsst/pipe/base/graph/_versionDeserializers.py index 0bcf020f5..908e5d681 100644 --- a/python/lsst/pipe/base/graph/_versionDeserializers.py +++ b/python/lsst/pipe/base/graph/_versionDeserializers.py @@ -545,8 +545,8 @@ def constructGraph( container = {} datasetDict = _DatasetTracker(createInverse=True) taskToQuantumNode: defaultdict[TaskDef, set[QuantumNode]] = defaultdict(set) - initInputRefs: dict[TaskDef, list[DatasetRef]] = {} - initOutputRefs: dict[TaskDef, list[DatasetRef]] = {} + initInputRefs: dict[str, list[DatasetRef]] = {} + initOutputRefs: dict[str, list[DatasetRef]] = {} if universe is not None: if not universe.isCompatibleWith(self.infoMappings.universe): @@ -597,11 +597,11 @@ def constructGraph( # initInputRefs and initOutputRefs are optional if (refs := taskDefDump.get("initInputRefs")) is not None: - initInputRefs[recreatedTaskDef] = [ + initInputRefs[recreatedTaskDef.label] = [ cast(DatasetRef, DatasetRef.from_json(ref, universe=universe)) for ref in refs ] if (refs := taskDefDump.get("initOutputRefs")) is not None: - initOutputRefs[recreatedTaskDef] = [ + initOutputRefs[recreatedTaskDef.label] = [ cast(DatasetRef, DatasetRef.from_json(ref, universe=universe)) for ref in refs ] diff --git a/python/lsst/pipe/base/graph/graph.py b/python/lsst/pipe/base/graph/graph.py index 3b41989b2..0cce2b3d7 100644 --- a/python/lsst/pipe/base/graph/graph.py +++ b/python/lsst/pipe/base/graph/graph.py @@ -46,22 +46,28 @@ import networkx as nx from lsst.daf.butler import ( + Config, DatasetId, DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, + LimitedButler, Quantum, + QuantumBackedButler, ) +from lsst.daf.butler.datastore.record_data import DatastoreRecordData from lsst.daf.butler.persistence_context import PersistenceContextVars +from lsst.daf.butler.registry import ConflictingDefinitionError from lsst.resources import ResourcePath, ResourcePathExpression from lsst.utils.introspection import get_full_type_name from lsst.utils.packages import Packages from networkx.drawing.nx_agraph import write_dot +from ..config import PipelineTaskConfig from ..connections import iterConnections from ..pipeline import TaskDef -from ..pipeline_graph import PipelineGraph +from ..pipeline_graph import PipelineGraph, compare_packages, log_config_mismatch from ._implDetails import DatasetTypeName, _DatasetTracker from ._loadHelpers import LoadHelper from ._versionDeserializers import DESERIALIZER_MAP @@ -286,14 +292,14 @@ def _buildGraphs( # insertion self._taskToQuantumNode = dict(self._taskToQuantumNode.items()) - self._initInputRefs: dict[TaskDef, list[DatasetRef]] = {} - self._initOutputRefs: dict[TaskDef, list[DatasetRef]] = {} + self._initInputRefs: dict[str, list[DatasetRef]] = {} + self._initOutputRefs: dict[str, list[DatasetRef]] = {} self._globalInitOutputRefs: list[DatasetRef] = [] self._registryDatasetTypes: list[DatasetType] = [] if initInputs is not None: - self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()} + self._initInputRefs = 
{taskDef.label: list(refs) for taskDef, refs in initInputs.items()} if initOutputs is not None: - self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()} + self._initOutputRefs = {taskDef.label: list(refs) for taskDef, refs in initOutputs.items()} if globalInitOutputs is not None: self._globalInitOutputRefs = list(globalInitOutputs) if registryDatasetTypes is not None: @@ -812,6 +818,38 @@ def metadata(self) -> MappingProxyType[str, Any]: """ return MappingProxyType(self._metadata) + def get_init_input_refs(self, task_label: str) -> list[DatasetRef]: + """Return the DatasetRefs for the given task's init inputs. + + Parameters + ---------- + task_label : `str` + Label of the task. + + Returns + ------- + refs : `list` [ `lsst.daf.butler.DatasetRef` ] + Dataset references. Guaranteed to be a new list, not internal + state. + """ + return list(self._initInputRefs.get(task_label, ())) + + def get_init_output_refs(self, task_label: str) -> list[DatasetRef]: + """Return the DatasetRefs for the given task's init outputs. + + Parameters + ---------- + task_label : `str` + Label of the task. + + Returns + ------- + refs : `list` [ `lsst.daf.butler.DatasetRef` ] + Dataset references. Guaranteed to be a new list, not internal + state. + """ + return list(self._initOutputRefs.get(task_label, ())) + def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None: """Return DatasetRefs for a given task InitInputs. @@ -826,7 +864,7 @@ def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None: DatasetRef for the task InitInput, can be `None`. This can return either resolved or non-resolved reference. """ - return self._initInputRefs.get(taskDef) + return self._initInputRefs.get(taskDef.label) def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None: """Return DatasetRefs for a given task InitOutputs. @@ -843,7 +881,7 @@ def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None: either resolved or non-resolved reference. Resolved reference will match Quantum's initInputs if this is an intermediate dataset type. """ - return self._initOutputRefs.get(taskDef) + return self._initOutputRefs.get(taskDef.label) def globalInitOutputRefs(self) -> list[DatasetRef]: """Return DatasetRefs for global InitOutputs. @@ -1027,9 +1065,9 @@ def _buildSaveObjectImpl(self, returnHeader: bool = False) -> bytearray | tuple[ taskDef.config.saveToStream(stream) taskDescription["config"] = stream.getvalue() taskDescription["label"] = taskDef.label - if (refs := self._initInputRefs.get(taskDef)) is not None: + if (refs := self._initInputRefs.get(taskDef.label)) is not None: taskDescription["initInputRefs"] = [ref.to_json() for ref in refs] - if (refs := self._initOutputRefs.get(taskDef)) is not None: + if (refs := self._initOutputRefs.get(taskDef.label)) is not None: taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs] inputs = [] @@ -1403,3 +1441,218 @@ def getSummary(self) -> QgraphSummary: qts.numOutputs[k.name] += 1 return summary + + def make_init_qbb( + self, + butler_config: Config | ResourcePathExpression, + *, + config_search_paths: Iterable[str] | None = None, + ) -> QuantumBackedButler: + """Construct an quantum-backed butler suitable for reading and writing + init input and init output datasets, respectively. + + This requires the full graph to have been loaded. 
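A minimal usage sketch of the label-keyed accessors defined above (they replace the `TaskDef`-keyed dictionaries changed earlier in this diff); the graph file path is hypothetical and the graph is assumed to be fully loaded:

    from lsst.pipe.base import QuantumGraph

    qgraph = QuantumGraph.loadUri("pipeline.qgraph")  # hypothetical path
    for task_label in qgraph.pipeline_graph.tasks:
        for ref in qgraph.get_init_input_refs(task_label):
            print(task_label, "init-input:", ref.datasetType.name)
        for ref in qgraph.get_init_output_refs(task_label):
            print(task_label, "init-output:", ref.datasetType.name)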
+ + Parameters + ---------- + butler_config : `~lsst.daf.butler.Config` or \ + `~lsst.resources.ResourcePathExpression` + A butler repository root, configuration filename, or configuration + instance. + config_search_paths : `~collections.abc.Iterable` [ `str` ], optional + Additional search paths for butler configuration. + + Returns + ------- + qbb : `~lsst.daf.butler.QuantumBackedButler` + A limited butler that can ``get`` init-input datasets and ``put`` + init-output datasets. + """ + universe = self.universe + # Collect all init input/output dataset IDs. + predicted_inputs: set[DatasetId] = set() + predicted_outputs: set[DatasetId] = set() + pipeline_graph = self.pipeline_graph + for task_label in pipeline_graph.tasks: + predicted_inputs.update(ref.id for ref in self.get_init_input_refs(task_label)) + predicted_outputs.update(ref.id for ref in self.get_init_output_refs(task_label)) + predicted_outputs.update(ref.id for ref in self.globalInitOutputRefs()) + # remove intermediates from inputs + predicted_inputs -= predicted_outputs + # Very inefficient way to extract datastore records from quantum graph, + # we have to scan all quanta and look at their datastore records. + datastore_records: dict[str, DatastoreRecordData] = {} + for quantum_node in self: + for store_name, records in quantum_node.quantum.datastore_records.items(): + subset = records.subset(predicted_inputs) + if subset is not None: + datastore_records.setdefault(store_name, DatastoreRecordData()).update(subset) + + dataset_types = {dstype.name: dstype for dstype in self.registryDatasetTypes()} + # Make butler from everything. + return QuantumBackedButler.from_predicted( + config=butler_config, + predicted_inputs=predicted_inputs, + predicted_outputs=predicted_outputs, + dimensions=universe, + datastore_records=datastore_records, + search_paths=list(config_search_paths) if config_search_paths is not None else None, + dataset_types=dataset_types, + ) + + def write_init_outputs(self, butler: LimitedButler, skip_existing: bool = True) -> None: + """Write the init-output datasets for all tasks in the quantum graph. + + Parameters + ---------- + butler : `lsst.daf.butler.LimitedButler` + A limited butler data repository client. + skip_existing : `bool`, optional + If `True` (default) ignore init-outputs that already exist. If + `False`, raise. + + Raises + ------ + lsst.daf.butler.registry.ConflictingDefinitionError + Raised if an init-output dataset already exists and + ``skip_existing=False``. + """ + # Extract init-input and init-output refs from the QG. + input_refs: dict[str, DatasetRef] = {} + output_refs: dict[str, DatasetRef] = {} + for task_node in self.pipeline_graph.tasks.values(): + input_refs.update( + {ref.datasetType.name: ref for ref in self.get_init_input_refs(task_node.label)} + ) + output_refs.update( + { + ref.datasetType.name: ref + for ref in self.get_init_output_refs(task_node.label) + if ref.datasetType.name != task_node.init.config_output.dataset_type_name + } + ) + for ref, is_stored in butler.stored_many(output_refs.values()).items(): + if is_stored: + if not skip_existing: + raise ConflictingDefinitionError(f"Init-output dataset {ref} already exists.") + # We'll `put` whatever's left in output_refs at the end. + del output_refs[ref.datasetType.name] + # Instantiate tasks, reading overall init-inputs and gathering + # init-output in-memory objects. 
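A sketch of how `make_init_qbb` (defined above) and `init_output_run` (defined further below in this class) might be combined by an executor to initialize a run without a registry connection; the repository config and graph paths are hypothetical:

    from lsst.pipe.base import QuantumGraph

    qgraph = QuantumGraph.loadUri("pipeline.qgraph")  # must be loaded in full
    qbb = qgraph.make_init_qbb("/repo/butler.yaml", config_search_paths=None)
    # Writes configs, the packages dataset, and task init-outputs through the
    # quantum-backed butler; the written datasets can later be transferred
    # back to the central repository with Butler.transfer_from.
    qgraph.init_output_run(qbb)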
+ init_outputs: list[tuple[Any, DatasetType]] = [] + self.pipeline_graph.instantiate_tasks( + get_init_input=lambda dataset_type: butler.get( + input_refs[dataset_type.name].overrideStorageClass(dataset_type.storageClass) + ), + init_outputs=init_outputs, + ) + # Write init-outputs that weren't already present. + for obj, dataset_type in init_outputs: + if new_ref := output_refs.get(dataset_type.name): + assert ( + new_ref.datasetType.storageClass_name == dataset_type.storageClass_name + ), "QG init refs should use task connection storage classes." + butler.put(obj, new_ref) + + def write_configs(self, butler: LimitedButler, compare_existing: bool = True) -> None: + """Write the config datasets for all tasks in the quantum graph. + + Parameters + ---------- + butler : `lsst.daf.butler.LimitedButler` + A limited butler data repository client. + compare_existing : `bool`, optional + If `True` check configs that already exist for consistency. If + `False`, always raise if configs already exist. + + Raises + ------ + lsst.daf.butler.registry.ConflictingDefinitionError + Raised if an config dataset already exists and + ``compare_existing=False``, or if the existing config is not + consistent with the config in the quantum graph. + """ + to_put: list[tuple[PipelineTaskConfig, DatasetRef]] = [] + for task_node in self.pipeline_graph.tasks.values(): + dataset_type_name = task_node.init.config_output.dataset_type_name + (ref,) = [ + ref + for ref in self.get_init_output_refs(task_node.label) + if ref.datasetType.name == dataset_type_name + ] + try: + old_config = butler.get(ref) + except (LookupError, FileNotFoundError): + old_config = None + if old_config is not None: + if not compare_existing: + raise ConflictingDefinitionError(f"Config dataset {ref} already exists.") + if not task_node.config.compare(old_config, shortcut=False, output=log_config_mismatch): + raise ConflictingDefinitionError( + f"Config does not match existing task config {dataset_type_name!r} in " + "butler; tasks configurations must be consistent within the same run collection." + ) + else: + to_put.append((task_node.config, ref)) + # We do writes at the end to minimize the mess we leave behind when we + # raise an exception. + for config, ref in to_put: + butler.put(config, ref) + + def write_packages(self, butler: LimitedButler, compare_existing: bool = True) -> None: + """Write the 'packages' dataset for the currently-active software + versions. + + Parameters + ---------- + butler : `lsst.daf.butler.LimitedButler` + A limited butler data repository client. + compare_existing : `bool`, optional + If `True` check packages that already exist for consistency. If + `False`, always raise if the packages dataset already exists. + + Raises + ------ + lsst.daf.butler.registry.ConflictingDefinitionError + Raised if the packages dataset already exists and is not consistent + with the current packages. + """ + new_packages = Packages.fromSystem() + (ref,) = self.globalInitOutputRefs() + try: + packages = butler.get(ref) + except (LookupError, FileNotFoundError): + packages = None + if packages is not None: + if not compare_existing: + raise ConflictingDefinitionError(f"Packages dataset {ref} already exists.") + if compare_packages(packages, new_packages): + # have to remove existing dataset first; butler has no + # replace option. 
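For reference, the three write methods in this class are exactly what `init_output_run` (defined just below) calls, in this order; passing `False` for the corresponding keyword makes each raise `ConflictingDefinitionError` instead of tolerating datasets that already exist (`limited_butler` is hypothetical):

    qgraph.write_configs(limited_butler, compare_existing=True)
    qgraph.write_packages(limited_butler, compare_existing=True)
    qgraph.write_init_outputs(limited_butler, skip_existing=True)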
+ butler.pruneDatasets([ref], unstore=True, purge=True) + butler.put(packages, ref) + else: + butler.put(new_packages, ref) + + def init_output_run(self, butler: LimitedButler, existing: bool = True) -> None: + """Initialize a new output RUN collection by writing init-output + datasets (including configs and packages). + + Parameters + ---------- + butler : `lsst.daf.butler.LimitedButler` + A limited butler data repository client. + existing : `bool`, optional + If `True` check or ignore outputs that already exist. If + `False`, always raise if an output dataset already exists. + + Raises + ------ + lsst.daf.butler.registry.ConflictingDefinitionError + Raised if there are existing init output datasets, and either + ``existing=False`` or their contents are not compatible with this + graph. + """ + self.write_configs(butler, compare_existing=existing) + self.write_packages(butler, compare_existing=existing) + self.write_init_outputs(butler, skip_existing=existing) diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py index f2e7ddf48..e17bdd775 100644 --- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -26,21 +26,34 @@ # along with this program. If not, see . from __future__ import annotations -__all__ = ("PipelineGraph",) +__all__ = ("PipelineGraph", "log_config_mismatch", "compare_packages") import gzip import itertools import json +import logging from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Set from typing import TYPE_CHECKING, Any, BinaryIO, Literal, TypeVar, cast import networkx import networkx.algorithms.bipartite import networkx.algorithms.dag -from lsst.daf.butler import DataCoordinate, DataId, DatasetType, DimensionGroup, DimensionUniverse, Registry -from lsst.daf.butler.registry import MissingDatasetTypeError +from lsst.daf.butler import ( + Butler, + DataCoordinate, + DataId, + DatasetRef, + DatasetType, + DimensionGroup, + DimensionUniverse, + MissingDatasetTypeError, +) +from lsst.daf.butler.registry import ConflictingDefinitionError, Registry from lsst.resources import ResourcePath, ResourcePathExpression +from lsst.utils.packages import Packages +from .._dataset_handle import InMemoryDatasetHandle +from ..automatic_connection_constants import PACKAGES_INIT_OUTPUT_NAME, PACKAGES_INIT_OUTPUT_STORAGE_CLASS from ._dataset_types import DatasetTypeNode from ._edges import Edge, ReadEdge, WriteEdge from ._exceptions import ( @@ -65,6 +78,8 @@ _G = TypeVar("_G", bound=networkx.DiGraph | networkx.MultiDiGraph) +_LOG = logging.getLogger("lsst.pipe.base.pipeline_graph") + class PipelineGraph: """A graph representation of fully-configured pipeline. @@ -1524,6 +1539,343 @@ def split_independent(self) -> Iterable[PipelineGraph]: component_subgraph.sort() yield component_subgraph + ########################################################################### + # + # Data repository/collection initialization + # + ########################################################################### + + @property + def packages_dataset_type(self) -> DatasetType: + """The special "packages" dataset type that records software versions. + + This is not associated with a task and hence is + not considered part of the pipeline graph in other respects, but it + does get written with other provenance datasets. 
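A sketch of the registration workflow built on the property above and the `register_dataset_types` / `check_dataset_type_registrations` methods defined just below; the repository path is hypothetical:

    from lsst.daf.butler import Butler

    butler = Butler.from_config("/repo", writeable=True)  # hypothetical repo
    pipeline_graph.resolve(butler.registry)
    # Register outputs, intermediates, and the special "packages" dataset type.
    pipeline_graph.register_dataset_types(butler, include_packages=True)
    # Later, e.g. at execution time, verify nothing has changed in the repo.
    pipeline_graph.check_dataset_type_registrations(butler)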
+ """ + if self.universe is None: + raise UnresolvedGraphError( + "PipelineGraph must be resolved in order to get the packages dataset type." + ) + return DatasetType(PACKAGES_INIT_OUTPUT_NAME, self.universe.empty, PACKAGES_INIT_OUTPUT_STORAGE_CLASS) + + def register_dataset_types(self, butler: Butler, include_packages: bool = True) -> None: + """Register all dataset types in a data repository. + + Parameters + ---------- + butler : `~lsst.daf.butler.Butler` + Data repository client. + include_packages : `bool`, optional + Whether to include the special "packages" dataset type that records + software versions (this is not associated with a task and hence is + not considered part of the pipeline graph in other respects, but it + does get written with other provenance datasets). + """ + dataset_types = [node.dataset_type for node in self.dataset_types.values()] + if include_packages: + dataset_types.append(self.packages_dataset_type) + for dataset_type in dataset_types: + butler.registry.registerDatasetType(dataset_type) + + def check_dataset_type_registrations(self, butler: Butler, include_packages: bool = True) -> None: + """Check that dataset type registrations in a data repository match + the definitions in this pipeline graph. + + Parameters + ---------- + butler : `~lsst.daf.butler.Butler` + Data repository client. + include_packages : `bool`, optional + Whether to include the special "packages" dataset type that records + software versions (this is not associated with a task and hence is + not considered part of the pipeline graph in other respects, but it + does get written with other provenance datasets). + + Raises + ------ + lsst.daf.butler.MissingDatasetTypeError + Raised if one or more non-optional-input or output dataset types in + the pipeline is not registered at all. + lsst.daf.butler.ConflictingDefinitionError + Raised if the definition in the data repository is not identical + to the definition in the pipeline graph. + + Notes + ----- + Note that dataset type definitions that are storage-class-conversion + compatible but not identical are not permitted by these checks, because + the expectation is that these differences are handled by `resolve`, + which makes the pipeline graph use the data repository definitions. + This method is intended to check that none of those definitions have + changed. + """ + dataset_types = [node.dataset_type for node in self.dataset_types.values()] + if include_packages: + dataset_types.append(self.packages_dataset_type) + missing_dataset_types: list[str] = [] + for dataset_type in dataset_types: + try: + expected = butler.registry.getDatasetType(dataset_type.name) + except MissingDatasetTypeError: + expected = None + if expected is None: + # The user probably forgot to register dataset types + # at least once (which should be an error), + # but we could also get here if this is an optional input for + # which no datasets were found in this repo (not an error). + if ( + not ( + self.producer_of(dataset_type.name) is None + and all( + self.tasks[input_edge.task_label].is_optional(input_edge.connection_name) + for input_edge in self.consuming_edges_of(dataset_type.name) + ) + ) + or dataset_type.name == PACKAGES_INIT_OUTPUT_NAME + ): + missing_dataset_types.append(dataset_type.name) + elif expected != dataset_type: + raise ConflictingDefinitionError( + f"DatasetType definition in registry has changed since the pipeline graph was resolved: " + f"{dataset_type} (graph) != {expected} (registry)." 
+ ) + if missing_dataset_types: + plural = "s" if len(missing_dataset_types) != 1 else "" + raise MissingDatasetTypeError( + f"Missing dataset type definition{plural}: {', '.join(missing_dataset_types)}. " + "Dataset types have to be registered in advance (on the command-line, either via " + "`butler register-dataset-type` or the `--register-dataset-types` option to `pipetask run`." + ) + + def instantiate_tasks( + self, + get_init_input: Callable[[DatasetType], Any] | None = None, + init_outputs: list[tuple[Any, DatasetType]] | None = None, + ) -> list[PipelineTask]: + """Instantiate all tasks in the pipeline. + + Parameters + ---------- + get_init_input : `~collections.abc.Callable`, optional + Callable that accepts a single `~lsst.daf.butler.DatasetType` + parameter and returns the init-input dataset associated with that + dataset type. Must respect the storage class embedded in the type. + This is optional if the pipeline does not have any overall init + inputs. When a full butler is available, + `lsst.daf.butler.Butler.get` can be used directly here. + init_outputs : `list`, optional + A list of ``(obj, dataset type)`` init-output dataset pairs, to be + appended to in-place. Both the object and the dataset type will + correspond to the storage class of the output connection, which + may not be the same as the storage class on the graph's dataset + type node. + + Returns + ------- + tasks : `list` + Constructed `PipelineTask` instances. + """ + if not self.is_fully_resolved: + raise UnresolvedGraphError("Pipeline graph must be fully resolved before instantiating tasks.") + empty_data_id = DataCoordinate.make_empty(cast(DimensionUniverse, self.universe)) + handles: dict[str, InMemoryDatasetHandle] = {} + tasks: list[PipelineTask] = [] + for task_node in self.tasks.values(): + task_init_inputs: dict[str, Any] = {} + for read_edge in task_node.init.inputs.values(): + if (handle := handles.get(read_edge.dataset_type_name)) is not None: + obj = handle.get(storageClass=read_edge.storage_class_name) + elif ( + read_edge.component is not None + and (parent_handle := handles.get(read_edge.parent_dataset_type_name)) is not None + ): + obj = parent_handle.get( + storageClass=read_edge.storage_class_name, component=read_edge.component + ) + else: + dataset_type_node = self.dataset_types[read_edge.parent_dataset_type_name] + if get_init_input is None: + raise ValueError( + f"Task {task_node.label!r} requires init-input " + f"{read_edge.dataset_type_name} but no 'get_init_input' callback was provided." + ) + obj = get_init_input(read_edge.adapt_dataset_type(dataset_type_node.dataset_type)) + n_consumers = len(self.consumers_of(dataset_type_node.name)) + if ( + n_consumers > 1 + and read_edge.component is None + and read_edge.storage_class_name == dataset_type_node.storage_class_name + ): + # Caching what we just got is safe in general only + # if there was no storage class conversion, since + # a->b and a->c does not imply b->c. 
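A usage sketch for `instantiate_tasks` as documented above; `butler` is a hypothetical full butler whose default collections contain the overall init-inputs, and `init_outputs` collects the `(object, dataset type)` pairs for later `put` calls:

    init_outputs: list = []
    tasks = pipeline_graph.instantiate_tasks(
        get_init_input=butler.get,  # Butler.get satisfies the callback contract
        init_outputs=init_outputs,
    )
    for obj, dataset_type in init_outputs:
        print(f"{dataset_type.name}: {type(obj).__name__}")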
+ handles[read_edge.dataset_type_name] = InMemoryDatasetHandle( + obj, + storageClass=dataset_type_node.storage_class, + dataId=empty_data_id, + copy=True, + ) + task_init_inputs[read_edge.connection_name] = obj + task = task_node.task_class( + config=task_node.config, initInputs=task_init_inputs, name=task_node.label + ) + tasks.append(task) + for write_edge in task_node.init.outputs.values(): + dataset_type_node = self.dataset_types[write_edge.parent_dataset_type_name] + obj = getattr(task, write_edge.connection_name) + # We don't immediately coerce obj to the dataset_type_node + # storage class (which should be the repo storage class, if + # there is one) when appending to `init_outputs` because a + # formatter might be able to do a better job of that later; + # instead we pair it with a dataset type that's consistent with + # the in-memory type. We do coerce when populating `handles`, + # though, because going through the dataset_type_node storage + # class is the conversion path we checked when we resolved the + # pipeline graph. + if init_outputs is not None: + init_outputs.append((obj, write_edge.adapt_dataset_type(dataset_type_node.dataset_type))) + n_consumers = len(self.consumers_of(dataset_type_node.name)) + if n_consumers > 0: + handles[dataset_type_node.name] = InMemoryDatasetHandle( + dataset_type_node.storage_class.coerce_type(obj), + dataId=empty_data_id, + storageClass=dataset_type_node.storage_class, + copy=(n_consumers > 1), + ) + return tasks + + def write_init_outputs(self, butler: Butler) -> None: + """Write the init-output datasets for all tasks in the pipeline graph. + + Parameters + ---------- + butler : `lsst.daf.butler.Butler` + A full butler data repository client with its default run set + to the collection where datasets should be written. + + Notes + ----- + Datasets that already exist in the butler's output run collection will + not be written. + + This method writes outputs with new random dataset IDs and should + hence only be used when writing init-outputs prior to building a + `QuantumGraph`. Use `QuantumGraph.write_init_outputs` if a quantum + graph has already been built. + """ + init_outputs: list[tuple[Any, DatasetType]] = [] + self.instantiate_tasks(butler.get, init_outputs) + found_refs: dict[str, DatasetRef] = {} + to_put: list[tuple[Any, DatasetType]] = [] + for obj, dataset_type in init_outputs: + if (ref := butler.find_dataset(dataset_type, collections=butler.run)) is not None: + found_refs[dataset_type.name] = ref + else: + to_put.append((obj, dataset_type)) + for ref, stored in butler.stored_many(found_refs.values()).items(): + if not stored: + raise FileNotFoundError( + f"Init-output dataset {ref.datasetType.name!r} was found in RUN {ref.run!r} " + f"but had not actually been stored (or was stored and later deleted)." + ) + for obj, dataset_type in to_put: + butler.put(obj, dataset_type) + + def write_configs(self, butler: Butler) -> None: + """Write the config datasets for all tasks in the pipeline graph. + + Parameters + ---------- + butler : `lsst.daf.butler.Butler` + A full butler data repository client with its default run set + to the collection where datasets should be written. + + Notes + ----- + Config datasets that already exist in the butler's output run + collection will be checked for consistency. + + This method writes outputs with new random dataset IDs and should + hence only be used when writing init-outputs prior to building a + `QuantumGraph`. Use `QuantumGraph.write_configs` if a quantum graph + has already been built. 
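Putting the pieces together for the pre-QuantumGraph case described in the notes above, with a full butler whose default run is the new output collection; the repository path, run, and input collection names are hypothetical:

    from lsst.daf.butler import Butler

    butler = Butler.from_config(
        "/repo",
        writeable=True,
        run="u/user/run1",
        collections=["overall_inputs", "u/user/run1"],  # init-inputs must be findable
    )
    pipeline_graph.resolve(butler.registry)
    pipeline_graph.register_dataset_types(butler)
    # Writes configs, the "packages" dataset, and task init-outputs with new
    # random dataset IDs; prefer the QuantumGraph methods once a QG exists.
    pipeline_graph.init_output_run(butler)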
+ + Raises + ------ + lsst.daf.butler.registry.ConflictingDefinitionError + Raised if a config dataset already exists and is not consistent + with the config in the pipeline graph. + """ + to_put: list[tuple[PipelineTaskConfig, str]] = [] + for task_node in self.tasks.values(): + dataset_type_name = task_node.init.config_output.dataset_type_name + if (ref := butler.find_dataset(dataset_type_name, collections=butler.run)) is not None: + old_config = butler.get(ref) + if not task_node.config.compare(old_config, shortcut=False, output=log_config_mismatch): + raise ConflictingDefinitionError( + f"Config does not match existing task config {dataset_type_name!r} in " + "butler; tasks configurations must be consistent within the same run collection" + ) + else: + to_put.append((task_node.config, dataset_type_name)) + # We do writes at the end to minimize the mess we leave behind when we + # raise an exception. + for config, dataset_type_name in to_put: + butler.put(config, dataset_type_name) + + def write_packages(self, butler: Butler) -> None: + """Write the 'packages' dataset for the currently-active software + versions. + + Parameters + ---------- + butler : `lsst.daf.butler.Butler` + A full butler data repository client with its default run set + to the collection where datasets should be written. + + Notes + ----- + If the packages dataset already exists, it will be compared to the + versions in the current packages. New packages that weren't present + before are not considered an inconsistency. + + This method writes outputs with new random dataset IDs and should + hence only be used when writing init-outputs prior to building a + `QuantumGraph`. Use `QuantumGraph.write_packages` if a quantum graph + has already been built. + + Raises + ------ + lsst.daf.butler.registry.ConflictingDefinitionError + Raised if the packages dataset already exists and is not consistent + with the current packages. + """ + new_packages = Packages.fromSystem() + if (ref := butler.find_dataset(self.packages_dataset_type)) is not None: + packages = butler.get(ref) + if compare_packages(packages, new_packages): + # have to remove existing dataset first; butler has no + # replace option. + butler.pruneDatasets([ref], unstore=True, purge=True) + butler.put(packages, ref) + else: + butler.put(new_packages, self.packages_dataset_type) + + def init_output_run(self, butler: Butler) -> None: + """Initialize a new output RUN collection by writing init-output + datasets (including configs and packages). + + Parameters + ---------- + butler : `lsst.daf.butler.Butler` + A full butler data repository client with its default run set + to the collection where datasets should be written. + """ + self.write_configs(butler) + self.write_packages(butler) + self.write_init_outputs(butler) + ########################################################################### # # Class- and Package-Private Methods. @@ -1818,3 +2170,51 @@ def _reset(self) -> None: _dataset_types: DatasetTypeMappingView _raw_data_id: dict[str, Any] _universe: DimensionUniverse | None + + +def log_config_mismatch(msg: str) -> None: + """Log messages about configuration mismatch. + + Parameters + ---------- + msg : `str` + Log message to use. + """ + _LOG.fatal("Comparing configuration: %s", msg) + + +def compare_packages(packages: Packages, new_packages: Packages) -> bool: + """Compare two versions of Packages. + + Parameters + ---------- + packages : `Packages` + Previously recorded package versions. 
Updated in place to include + any new packages that weren't present before. + new_packages : `Packages` + New set of package versions. + + Returns + ------- + updated : `bool` + `True` if ``packages`` was updated, `False` if not. + + Raises + ------ + ConflictingDefinitionError + Raised if versions are inconsistent. + """ + diff = new_packages.difference(packages) + if diff: + versions_str = "; ".join(f"{pkg}: {diff[pkg][1]} vs {diff[pkg][0]}" for pkg in diff) + raise ConflictingDefinitionError(f"Package versions mismatch: ({versions_str})") + else: + _LOG.debug("new packages are consistent with old") + # Update the old set of packages in case we have more packages + # that haven't been persisted. + extra = new_packages.extra(packages) + if extra: + _LOG.debug("extra packages: %s", extra) + packages.update(new_packages) + return True + return False diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py index 576869bf4..5c462c5e2 100644 --- a/python/lsst/pipe/base/pipeline_graph/_tasks.py +++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py @@ -48,7 +48,7 @@ from .. import automatic_connection_constants as acc from ..connections import PipelineTaskConnections -from ..connectionTypes import BaseConnection, InitOutput, Output +from ..connectionTypes import BaseConnection, BaseInput, InitOutput, Output from ._edges import Edge, ReadEdge, WriteEdge from ._exceptions import TaskNotImportedError, UnresolvedGraphError from ._nodes import NodeKey, NodeType @@ -216,7 +216,7 @@ class TaskInitNode: - ``task_class_name`` - ``bipartite`` (see `NodeType.bipartite`) - ``task_class`` (only if `is_imported` is `True`) - - ``config`` (only if `is_importd` is `True`) + - ``config`` (only if `is_imported` is `True`) """ def __init__( @@ -798,6 +798,23 @@ def get_lookup_function( """ return getattr(self._get_imported_data().connection_map[connection_name], "lookupFunction", None) + def is_optional(self, connection_name: str) -> bool: + """Check whether the given connection has ``minimum==0``. + + Parameters + ---------- + connection_name : `str` + Name of the connection. + + Returns + ------- + optional : `bool` + Whether this task can run without any datasets for the given + connection. + """ + connection = getattr(self.get_connections(), connection_name) + return isinstance(connection, BaseInput) and connection.minimum == 0 + def get_connections(self) -> PipelineTaskConnections: """Return the connections class instance for this task. diff --git a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py index 0c2d7ca23..70931e4af 100644 --- a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py +++ b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py @@ -563,7 +563,7 @@ class DynamicConnectionConfig(Config): default=True, ) minimum = Field[int]( - doc="Minimum number of datasets per quantum requried for this connection. Ignored for non-inputs.", + doc="Minimum number of datasets per quantum required for this connection. Ignored for non-inputs.", dtype=int, default=1, ) diff --git a/tests/test_init_output_run.py b/tests/test_init_output_run.py new file mode 100644 index 000000000..2cc6341e1 --- /dev/null +++ b/tests/test_init_output_run.py @@ -0,0 +1,535 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). 
+# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import itertools +import tempfile +import unittest +from collections.abc import Iterator +from contextlib import contextmanager +from typing import ClassVar + +import lsst.utils.tests +from lsst.daf.butler import ( + Butler, + DatasetRef, + DatasetType, + MissingDatasetTypeError, + QuantumBackedButler, + SerializedDatasetType, + StorageClassFactory, +) +from lsst.daf.butler.registry import ConflictingDefinitionError +from lsst.pipe.base import QuantumGraph +from lsst.pipe.base.all_dimensions_quantum_graph_builder import AllDimensionsQuantumGraphBuilder +from lsst.pipe.base.pipeline_graph import PipelineGraph +from lsst.pipe.base.tests.mocks import ( + DynamicConnectionConfig, + DynamicTestPipelineTask, + DynamicTestPipelineTaskConfig, + MockDataset, +) + + +def _have_example_storage_classes() -> bool: + """Check whether some storage classes work as expected. + + Given that these have registered converters, it shouldn't actually be + necessary to import those types in order to determine that they're + convertible, but the storage class machinery is implemented such that types + that can't be imported can't be converted, and while that's inconvenient + here it's totally fine in non-testing scenarios where you only care about a + storage class if you can actually use it. + """ + getter = StorageClassFactory().getStorageClass + return ( + getter("ArrowTable").can_convert(getter("ArrowAstropy")) + and getter("ArrowAstropy").can_convert(getter("ArrowTable")) + and getter("ArrowTable").can_convert(getter("DataFrame")) + and getter("DataFrame").can_convert(getter("ArrowTable")) + ) + + +class InitOutputRunTestCase(unittest.TestCase): + """Tests for the init_output_run methods of PipelineGraph and + QuantumGraph. + """ + + INPUT_COLLECTION: ClassVar[str] = "overall_inputs" + + @contextmanager + def make_butler(self) -> Iterator[Butler]: + """Wrap a temporary local butler repository in a context manager.""" + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as root: + Butler.makeRepo(root) + butler = Butler.from_config(root, writeable=True) + yield butler + + @contextmanager + def prep_butler(self, pipeline_graph: PipelineGraph) -> Iterator[Butler]: + """Create a temporary local butler repository with the dataset types + and input datasets needed by a pipeline graph. 
+ + This also resolves the pipeline graph and checks dataset types + immediately after they are registered, providing test coverage for the + methods that do that. + """ + with self.make_butler() as butler: + butler.collections.register(self.INPUT_COLLECTION) + pipeline_graph.resolve(butler.registry) + with self.assertRaises(MissingDatasetTypeError): + pipeline_graph.check_dataset_type_registrations(butler) + pipeline_graph.register_dataset_types(butler) + pipeline_graph.check_dataset_type_registrations(butler) + for _, dataset_type_node in pipeline_graph.iter_overall_inputs(): + butler.put( + MockDataset( + dataset_id=None, + dataset_type=SerializedDatasetType( + name=dataset_type_node.name, + dimensions=[], + storageClass=dataset_type_node.storage_class_name, + ), + data_id={}, + run=self.INPUT_COLLECTION, + ), + dataset_type_node.name, + run=self.INPUT_COLLECTION, + ) + yield butler + + def find_init_output_refs( + self, pipeline_graph: PipelineGraph, butler: Butler + ) -> dict[str, list[DatasetRef]]: + """Find the init-output datasets of a pipeline graph in a butler + repository. + + Parameters + ---------- + pipeline_graph : `PipelineGraph` + Pipeline graph. + butler : `Butler` + Full butler client. + + Returns + ------- + init_output_refs : `dict` + Dataset references, keyed by task label. Storage classes will + match the data repository definitions of the dataset types. The + special 'packages' dataset type will be included under a '*' key. + """ + init_output_refs: dict[str, list[DatasetRef]] = {} + for task_node in pipeline_graph.tasks.values(): + init_output_refs_for_task: list[DatasetRef] = [] + for write_edge in task_node.init.iter_all_outputs(): + ref = butler.find_dataset(write_edge.dataset_type_name) + # Check that the ref we got back uses the dataset type node's + # definition of the dataset type (including storage class). + self.assertEqual( + ref.datasetType, pipeline_graph.dataset_types[write_edge.dataset_type_name].dataset_type + ) + # Remember the version of the ref that has the task's storage + # class, in case they differ. + init_output_refs_for_task.append(write_edge.adapt_dataset_ref(ref)) + init_output_refs[task_node.label] = init_output_refs_for_task + init_output_refs["*"] = [butler.find_dataset(pipeline_graph.packages_dataset_type)] + return init_output_refs + + def get_quantum_graph_init_output_refs(self, quantum_graph: QuantumGraph) -> dict[str, list[DatasetRef]]: + """Extract dataset references from a QuantumGraph into the same form + as returned by `find_init_output_refs`. + """ + init_output_refs: dict[str, list[DatasetRef]] = {} + for task_label in quantum_graph.pipeline_graph.tasks: + init_output_refs[task_label] = quantum_graph.get_init_output_refs(task_label) + init_output_refs["*"] = list(quantum_graph.globalInitOutputRefs()) + return init_output_refs + + def assert_init_output_refs_equal( + self, a: dict[str, list[DatasetRef]], b: dict[str, list[DatasetRef]] + ) -> None: + """Check that two dictionaries of the form returned by + `find_init_output_refs` are equal. + """ + self.assertEqual(a.keys(), b.keys()) + for task_label, init_output_refs_for_task in a.items(): + self.assertCountEqual(init_output_refs_for_task, b[task_label]) + + def check_qbb_consistency( + self, init_output_refs: dict[str, list[DatasetRef]], qbb: QuantumBackedButler + ) -> None: + """Check that a quantum-backed butler sees all of the given datasets. + + Parameters + ---------- + init_output_refs : `dict` + Dataset references, keyed by task label. 
Storage classes should + match the data repository definitions of the dataset types. The + special 'packages' dataset type should be included under a '*' key. + qbb : `lsst.daf.butler.QuantumBackedButler` + A quantum-backed butler. + """ + for task_label, init_output_refs_for_task in init_output_refs.items(): + for ref, stored_in in qbb.stored_many(init_output_refs_for_task).items(): + self.assertTrue( + stored_in, msg=f"Init-input {ref} of task {task_label} not stored according to QBB." + ) + + def init_with_pipeline_graph_first( + self, pipeline_graph: PipelineGraph, butler: Butler, run: str + ) -> QuantumGraph: + """Test the init_output_run methods of PipelineGraph and QuantumGraph, + using the former to actually write init-outputs (with later attempts + correctly failing or doing nothing, depending on parameters). + """ + butler = butler._clone(run=run, collections=[self.INPUT_COLLECTION, run]) + pipeline_graph.init_output_run(butler) + init_output_refs = self.find_init_output_refs(pipeline_graph, butler) + # Build a QG with the init outputs already in place. + quantum_graph_builder = AllDimensionsQuantumGraphBuilder( + pipeline_graph, + butler, + skip_existing_in=[run], + output_run=run, + input_collections=[self.INPUT_COLLECTION], + ) + quantum_graph = quantum_graph_builder.build( + metadata={"output_run": run}, attach_datastore_records=True + ) + # Check that the QG refs are the same as the ones that were present + # already. + self.assert_init_output_refs_equal( + self.get_quantum_graph_init_output_refs(quantum_graph), + init_output_refs, + ) + # Initialize with the pipeline graph, should be a no-op. + pipeline_graph.init_output_run(butler) + self.assert_init_output_refs_equal( + self.find_init_output_refs(pipeline_graph, butler), + init_output_refs, + ) + # Initialize again with the QG; should be a no-op. + quantum_graph.init_output_run(butler) + self.assert_init_output_refs_equal( + self.find_init_output_refs(pipeline_graph, butler), + init_output_refs, + ) + # Initialize again with the QG but tell it to expect an empty run. + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.init_output_run(butler, existing=False) + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.write_configs(butler, compare_existing=False) + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.write_packages(butler, compare_existing=False) + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.write_init_outputs(butler, skip_existing=False) + # Make a QBB, check that it can see the init outputs. + qbb = quantum_graph.make_init_qbb(butler._config) + self.check_qbb_consistency(init_output_refs, qbb) + # Use QBB to initialize again, should be a no-op. + quantum_graph.init_output_run(qbb) + self.check_qbb_consistency(init_output_refs, qbb) + # Use QBB to initialize again but tell it to expect an empty run. + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.init_output_run(qbb, existing=False) + return quantum_graph + + def init_with_quantum_graph_first( + self, pipeline_graph: PipelineGraph, butler: Butler, run: str + ) -> QuantumGraph: + """Test the init_output_run methods of PipelineGraph and QuantumGraph, + using the latter to actually write init-outputs (with later attempts + correctly failing or doing nothing, depending on parameters). + """ + butler = butler._clone(run=run, collections=[self.INPUT_COLLECTION, run]) + # Build a QG. 
+ quantum_graph_builder = AllDimensionsQuantumGraphBuilder( + pipeline_graph, + butler, + input_collections=[self.INPUT_COLLECTION], + ) + quantum_graph = quantum_graph_builder.build( + metadata={"output_run": run}, attach_datastore_records=True + ) + # Initialize with the QG. + quantum_graph.init_output_run(butler) + # Check that the QG refs are the same as the ones we find in the repo. + init_output_refs = self.find_init_output_refs(pipeline_graph, butler) + self.assert_init_output_refs_equal( + self.get_quantum_graph_init_output_refs(quantum_graph), + init_output_refs, + ) + # Initialize again with the QG; should be a no-op. + quantum_graph.init_output_run(butler) + self.assert_init_output_refs_equal( + self.find_init_output_refs(pipeline_graph, butler), + init_output_refs, + ) + # Initialize again with the QG but tell it to expect an empty run. + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.init_output_run(butler, existing=False) + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.write_configs(butler, compare_existing=False) + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.write_packages(butler, compare_existing=False) + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.write_init_outputs(butler, skip_existing=False) + # Initialize with the pipeline graph, should be a no-op. + pipeline_graph.init_output_run(butler) + # Make a QBB, check that it can see the init outputs. + qbb = quantum_graph.make_init_qbb(butler._config) + self.check_qbb_consistency(init_output_refs, qbb) + # Use QBB to initialize again, should be a no-op. + quantum_graph.init_output_run(qbb) + self.check_qbb_consistency(init_output_refs, qbb) + # Use QBB to initialize again but tell it to expect an empty run. + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.init_output_run(qbb, existing=False) + return quantum_graph + + def init_with_qbb_first(self, pipeline_graph: PipelineGraph, butler: Butler, run: str) -> QuantumGraph: + """Test the init_output_run methods of PipelineGraph and QuantumGraph, + using the latter a quantum-backed butler to actually write init-outputs + (with later attempts correctly failing or doing nothing, depending on + parameters). + """ + butler = butler._clone(run=run, collections=[self.INPUT_COLLECTION, run]) + # Build a QG. + quantum_graph_builder = AllDimensionsQuantumGraphBuilder( + pipeline_graph, + butler, + input_collections=[self.INPUT_COLLECTION], + ) + quantum_graph = quantum_graph_builder.build( + metadata={"output_run": run}, attach_datastore_records=True + ) + # Make a quantum-backed butler and use it to initialize the run. + qbb = quantum_graph.make_init_qbb(butler._config) + quantum_graph.init_output_run(qbb) + init_output_refs = self.get_quantum_graph_init_output_refs(quantum_graph) + self.check_qbb_consistency(init_output_refs, qbb) + # Use QBB to initialize again, should be a no-op. + self.check_qbb_consistency(init_output_refs, qbb) + # Use QBB to initialize again but tell it to expect an empty run. + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.init_output_run(qbb, existing=False) + # Transferring datasets back to the main butler (i.e. insert DB entries + # for them). + butler.transfer_from(qbb, itertools.chain.from_iterable(init_output_refs.values())) + # Check that the QG refs are the same as the ones we find in the repo. 
+ self.assert_init_output_refs_equal( + self.find_init_output_refs(pipeline_graph, butler), + init_output_refs, + ) + # Initialize again with the QG; should be a no-op. + quantum_graph.init_output_run(butler) + self.assert_init_output_refs_equal( + self.find_init_output_refs(pipeline_graph, butler), + init_output_refs, + ) + # Initialize again with the QG but tell it to expect an empty run. + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.init_output_run(butler, existing=False) + # Initialize with the pipeline graph, should be a no-op. + pipeline_graph.init_output_run(butler) + return quantum_graph + + def test_two_tasks_no_conversions(self) -> None: + """Test a two-task pipeline with an overall init-input, an overall + init-output, and an init-intermediate. + """ + a_config = DynamicTestPipelineTaskConfig() + a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime") + a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime") + a_config.init_inputs["ii"] = DynamicConnectionConfig(dataset_type_name="input_init") + a_config.init_outputs["io"] = DynamicConnectionConfig(dataset_type_name="intermediate_init") + b_config = DynamicTestPipelineTaskConfig() + b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime") + b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime") + b_config.init_inputs["ii"] = DynamicConnectionConfig(dataset_type_name="intermediate_init") + b_config.init_outputs["io"] = DynamicConnectionConfig(dataset_type_name="output_init") + pipeline_graph = PipelineGraph() + pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config) + pipeline_graph.add_task("b", DynamicTestPipelineTask, b_config) + with self.prep_butler(pipeline_graph) as butler: + self.init_with_pipeline_graph_first(pipeline_graph, butler, "run1") + self.assertEqual(butler.get("a_config", collections="run1"), a_config) + self.assertEqual(butler.get("b_config", collections="run1"), b_config) + self.init_with_quantum_graph_first(pipeline_graph, butler, "run2") + self.assertEqual(butler.get("a_config", collections="run2"), a_config) + self.assertEqual(butler.get("b_config", collections="run2"), b_config) + self.init_with_qbb_first(pipeline_graph, butler, "run3") + self.assertEqual(butler.get("a_config", collections="run3"), a_config) + self.assertEqual(butler.get("b_config", collections="run3"), b_config) + + def test_optional_input_unregistered(self) -> None: + """Test that an optional input dataset type that is not registered is + not considered an error. 
+ """ + a_config = DynamicTestPipelineTaskConfig() + a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime", minimum=0) + a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime") + pipeline_graph = PipelineGraph() + pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config) + with self.make_butler() as butler: + pipeline_graph.resolve(butler.registry) + butler.registry.registerDatasetType(pipeline_graph.dataset_types["a_config"].dataset_type) + butler.registry.registerDatasetType(pipeline_graph.dataset_types["a_log"].dataset_type) + butler.registry.registerDatasetType(pipeline_graph.dataset_types["a_metadata"].dataset_type) + butler.registry.registerDatasetType(pipeline_graph.dataset_types["output_runtime"].dataset_type) + pipeline_graph.check_dataset_type_registrations(butler, include_packages=False) + + def test_registration_changed(self) -> None: + """Test that we get an error when dataset type registrations in a data + repository change between the time a pipeline graph is resolved (e.g. + at QG generation) and when dataset types are checked later (e.g. during + execution). + """ + a_config = DynamicTestPipelineTaskConfig() + a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime") + a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime") + pipeline_graph = PipelineGraph() + pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config) + with self.make_butler() as butler: + pipeline_graph.resolve(butler.registry) + pipeline_graph.register_dataset_types(butler) + butler.registry.removeDatasetType("input_runtime") + butler.registry.registerDatasetType( + DatasetType("input_runtime", {"instrument"}, "StructuredDataList", universe=butler.dimensions) + ) + with self.assertRaises(ConflictingDefinitionError): + pipeline_graph.check_dataset_type_registrations(butler) + + @unittest.skipUnless( + _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available." + ) + def test_init_intermediate_component(self) -> None: + """Test init_output_run with an init-intermediate that is written as + a composite and read as a component. 
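The optional-input test above exercises the new `TaskNode.is_optional` helper added in `_tasks.py` earlier in this diff; a sketch of querying it through a resolved graph (the dataset type name comes from the mock tasks used in these tests):

    for edge in pipeline_graph.consuming_edges_of("input_runtime"):
        task_node = pipeline_graph.tasks[edge.task_label]
        if task_node.is_optional(edge.connection_name):
            print(f"{edge.task_label}.{edge.connection_name} tolerates missing datasets")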
+ """ + a_config = DynamicTestPipelineTaskConfig() + a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime") + a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime") + a_config.init_outputs["io"] = DynamicConnectionConfig( + dataset_type_name="intermediate_init", storage_class="ArrowTable" + ) + b_config = DynamicTestPipelineTaskConfig() + b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime") + b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime") + b_config.init_inputs["ii"] = DynamicConnectionConfig( + dataset_type_name="intermediate_init.schema", storage_class="ArrowSchema" + ) + pipeline_graph = PipelineGraph() + pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config) + pipeline_graph.add_task("b", DynamicTestPipelineTask, b_config) + with self.prep_butler(pipeline_graph) as butler: + self.init_with_pipeline_graph_first(pipeline_graph, butler, "run1") + self.assertEqual(butler.get("a_config", collections="run1"), a_config) + self.assertEqual(butler.get("b_config", collections="run1"), b_config) + self.init_with_quantum_graph_first(pipeline_graph, butler, "run2") + self.assertEqual(butler.get("a_config", collections="run2"), a_config) + self.assertEqual(butler.get("b_config", collections="run2"), b_config) + self.init_with_qbb_first(pipeline_graph, butler, "run3") + self.assertEqual(butler.get("a_config", collections="run3"), a_config) + self.assertEqual(butler.get("b_config", collections="run3"), b_config) + + def test_no_get_init_input_callback(self) -> None: + """Test calling PipelineGraph.instantiate_tasks with no get_init_input + callback when one is necessary. + """ + a_config = DynamicTestPipelineTaskConfig() + a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime") + a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime") + a_config.init_inputs["ii"] = DynamicConnectionConfig(dataset_type_name="input_init") + pipeline_graph = PipelineGraph() + pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config) + with self.make_butler() as butler: + pipeline_graph.resolve(butler.registry) + with self.assertRaises(ValueError): + pipeline_graph.instantiate_tasks() + + def test_multiple_init_input_consumers(self) -> None: + """Test init_output_run when there are two tasks consuming the same + init-input. 
+ """ + a_config = DynamicTestPipelineTaskConfig() + a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime") + a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime") + a_config.init_inputs["ii"] = DynamicConnectionConfig(dataset_type_name="input_init") + a_config.init_outputs["io"] = DynamicConnectionConfig(dataset_type_name="output_init") + b_config = DynamicTestPipelineTaskConfig() + b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="intermediate_runtime") + b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime") + b_config.init_inputs["ii"] = DynamicConnectionConfig(dataset_type_name="input_init") + pipeline_graph = PipelineGraph() + pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config) + pipeline_graph.add_task("b", DynamicTestPipelineTask, b_config) + with self.prep_butler(pipeline_graph) as butler: + self.init_with_pipeline_graph_first(pipeline_graph, butler, "run1") + self.assertEqual(butler.get("a_config", collections="run1"), a_config) + self.assertEqual(butler.get("b_config", collections="run1"), b_config) + self.init_with_quantum_graph_first(pipeline_graph, butler, "run2") + self.assertEqual(butler.get("a_config", collections="run2"), a_config) + self.assertEqual(butler.get("b_config", collections="run2"), b_config) + self.init_with_qbb_first(pipeline_graph, butler, "run3") + self.assertEqual(butler.get("a_config", collections="run3"), a_config) + self.assertEqual(butler.get("b_config", collections="run3"), b_config) + + def test_config_change(self) -> None: + """Test init_output_run when there is an existing config that is + inconsistent with the one in the pipeline graph. + """ + a_config = DynamicTestPipelineTaskConfig() + a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="input_runtime") + a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="output_runtime") + pipeline_graph = PipelineGraph() + pipeline_graph.add_task("a", DynamicTestPipelineTask, a_config) + with self.prep_butler(pipeline_graph) as butler: + butler.collections.register("run1") + butler.put(DynamicTestPipelineTaskConfig(), "a_config", run="run1") + with self.assertRaises(ConflictingDefinitionError): + pipeline_graph.init_output_run( + butler._clone(run="run1", collections=[self.INPUT_COLLECTION, "run1"]) + ) + quantum_graph_builder = AllDimensionsQuantumGraphBuilder( + pipeline_graph, + butler, + skip_existing_in=["run1"], + output_run="run1", + input_collections=[self.INPUT_COLLECTION], + ) + quantum_graph = quantum_graph_builder.build( + metadata={"output_run": "run1"}, attach_datastore_records=True + ) + with self.assertRaises(ConflictingDefinitionError): + quantum_graph.init_output_run( + butler._clone(run="run1", collections=[self.INPUT_COLLECTION, "run1"]) + ) + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main() diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py index d4935ccd9..1095785db 100644 --- a/tests/test_pipeline_graph.py +++ b/tests/test_pipeline_graph.py @@ -121,6 +121,10 @@ def test_unresolved_accessors(self) -> None: self.assertEqual( repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)" ) + with self.assertRaises(UnresolvedGraphError): + self.graph.packages_dataset_type + with self.assertRaises(UnresolvedGraphError): + self.graph.instantiate_tasks() def test_sorting(self) -> None: """Test sort methods on PipelineGraph.""" @@ -198,6 +202,7 @@ def 
test_resolved_accessors(self) -> None: self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty) self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict") self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict") + self.assertEqual(self.graph.packages_dataset_type.name, acc.PACKAGES_INIT_OUTPUT_NAME) def test_resolved_xgraph_export(self) -> None: """Test exporting a resolved PipelineGraph to networkx in various