Merge pull request #363 from lsst/tickets/DM-40303
DM-40303: Fix pydantic v2 warnings
timj authored Aug 5, 2023
2 parents d81147c + 67e65e3 commit d1bbb55
Showing 30 changed files with 90 additions and 111 deletions.
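
The bulk of the pydantic-specific churn is the v1-to-v2 serialization rename (`dict()` → `model_dump()`) seen in the `graph.py` hunk below. A minimal standalone sketch of the pattern, using a hypothetical model rather than any of the real pipe_base classes:

```python
from pydantic import BaseModel


class Record(BaseModel):
    """Hypothetical model, standing in for the serialized dimension records."""

    name: str
    value: int


record = Record(name="detector", value=42)

# data = record.dict()      # pydantic v1 spelling; emits a deprecation warning under v2
data = record.model_dump()  # pydantic v2 spelling adopted by this commit

print(data)  # {'name': 'detector', 'value': 42}
```
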
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -14,14 +14,14 @@ repos:
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.10
language_version: python3.11
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.278
rev: v0.0.282
hooks:
- id: ruff
@@ -86,7 +86,7 @@ def adjustQuantum(self, inputs, outputs, label, data_id):
# and order them consistently (note that consistent ordering is not
# automatic).
adjusted_inputs = {}
for name, refs in zip(input_names, inputs_by_data_id):
for name, refs in zip(input_names, inputs_by_data_id, strict=True):
adjusted_inputs[name] = (
inputs[name][0],
[refs[data_id] for data_id in data_ids_to_keep],
@@ -86,7 +86,7 @@ def adjustQuantum(self, inputs, outputs, label, data_id):
# and order them consistently (note that consistent ordering is not
# automatic).
adjusted_inputs = {}
for name, refs in zip(input_names, inputs_by_data_id):
for name, refs in zip(input_names, inputs_by_data_id, strict=True):
adjusted_inputs[name] = (
inputs[name][0],
[refs[data_id] for data_id in data_ids_to_keep],
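
Adding `strict=True` (available since Python 3.10) makes `zip` raise instead of silently truncating when its arguments have different lengths. A small sketch, independent of the pipe_base code:

```python
names = ["a", "b", "c"]
refs = [1, 2]

# Plain zip silently drops the unmatched "c":
print(list(zip(names, refs)))  # [('a', 1), ('b', 2)]

# strict=True turns the length mismatch into an error:
try:
    list(zip(names, refs, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```
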
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -100,7 +100,7 @@ version = { attr = "lsst_versions.get_lsst_version" }

[tool.black]
line-length = 110
target-version = ["py310"]
target-version = ["py311"]

[tool.isort]
profile = "black"
@@ -163,7 +163,7 @@ select = [
"W", # pycodestyle
"D", # pydocstyle
]
target-version = "py310"
target-version = "py311"
extend-select = [
"RUF100", # Warn about unused noqa
]
5 changes: 2 additions & 3 deletions python/lsst/pipe/base/_instrument.py
@@ -23,6 +23,7 @@

__all__ = ("Instrument",)

import contextlib
import datetime
import os.path
from abc import ABCMeta, abstractmethod
@@ -325,10 +326,8 @@ def importAll(registry: Registry) -> None:
records = list(registry.queryDimensionRecords("instrument"))
for record in records:
cls = record.class_name
try:
with contextlib.suppress(Exception):
doImportType(cls)
except Exception:
pass

@abstractmethod
def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
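
`contextlib.suppress` is the idiomatic one-liner for a `try`/`except`/`pass` block whose only purpose is to ignore a failure. A self-contained sketch of the same shape (the helper below is hypothetical, not the real `doImportType`):

```python
import contextlib


def flaky_import(name: str) -> None:
    """Hypothetical stand-in for doImportType; raises if the module is missing."""
    __import__(name)


# Equivalent to: try: flaky_import(...) / except Exception: pass
with contextlib.suppress(Exception):
    flaky_import("not.a.real.module")

print("still running")
```
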
12 changes: 6 additions & 6 deletions python/lsst/pipe/base/_quantumContext.py
@@ -271,7 +271,7 @@ def get(
n_connections = len(dataset)
n_retrieved = 0
for i, (name, ref) in enumerate(dataset):
if isinstance(ref, (list, tuple)):
if isinstance(ref, list | tuple):
val = []
n_refs = len(ref)
for j, r in enumerate(ref):
@@ -302,7 +302,7 @@ def get(
"Completed retrieval of %d datasets from %d connections", n_retrieved, n_connections
)
return retVal
elif isinstance(dataset, (list, tuple)):
elif isinstance(dataset, list | tuple):
n_datasets = len(dataset)
retrieved = []
for i, x in enumerate(dataset):
@@ -314,7 +314,7 @@
if periodic.num_issued > 0:
_LOG.verbose("Completed retrieval of %d datasets", n_datasets)
return retrieved
elif isinstance(dataset, DatasetRef) or isinstance(dataset, DeferredDatasetRef) or dataset is None:
elif isinstance(dataset, DatasetRef | DeferredDatasetRef) or dataset is None:
return self._get(dataset)
else:
raise TypeError(
@@ -364,14 +364,14 @@ def put(
)
for name, refs in dataset:
valuesAttribute = getattr(values, name)
if isinstance(refs, (list, tuple)):
if isinstance(refs, list | tuple):
if len(refs) != len(valuesAttribute):
raise ValueError(f"There must be a object to put for every Dataset ref in {name}")
for i, ref in enumerate(refs):
self._put(valuesAttribute[i], ref)
else:
self._put(valuesAttribute, refs)
elif isinstance(dataset, (list, tuple)):
elif isinstance(dataset, list | tuple):
if not isinstance(values, Sequence):
raise ValueError("Values to put must be a sequence")
if len(dataset) != len(values):
@@ -402,7 +402,7 @@ def _checkMembership(self, ref: list[DatasetRef] | DatasetRef, inout: set) -> No
which may be important for Quanta with lots of
`~lsst.daf.butler.DatasetRef`.
"""
if not isinstance(ref, (list, tuple)):
if not isinstance(ref, list | tuple):
ref = [ref]
for r in ref:
if (r.datasetType, r.dataId) not in inout:
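
Since Python 3.10, `isinstance` accepts PEP 604 unions directly, which is why the tuple-of-types checks above become `list | tuple`. A quick illustration:

```python
ref = [1, 2, 3]

# Older spelling with a tuple of types:
print(isinstance(ref, (list, tuple)))    # True

# PEP 604 union spelling used throughout this commit (Python 3.10+):
print(isinstance(ref, list | tuple))     # True
print(isinstance("text", list | tuple))  # False
```
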
3 changes: 1 addition & 2 deletions python/lsst/pipe/base/_task_metadata.py
@@ -348,8 +348,7 @@ def keys(self) -> tuple[str, ...]:

def items(self) -> Iterator[tuple[str, Any]]:
"""Yield the top-level keys and values."""
for k, v in itertools.chain(self.scalars.items(), self.arrays.items(), self.metadata.items()):
yield (k, v)
yield from itertools.chain(self.scalars.items(), self.arrays.items(), self.metadata.items())

def __len__(self) -> int:
"""Return the number of items."""
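
`yield from` delegates iteration to the chained view directly, replacing the explicit loop. A small sketch of the equivalence:

```python
import itertools

scalars = {"exposure": 42}
arrays = {"detectors": [1, 2]}


def items_explicit():
    for k, v in itertools.chain(scalars.items(), arrays.items()):
        yield (k, v)


def items_delegated():
    # Same pairs, produced by delegating to the chained iterator.
    yield from itertools.chain(scalars.items(), arrays.items())


assert list(items_explicit()) == list(items_delegated())
```
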
2 changes: 1 addition & 1 deletion python/lsst/pipe/base/config.py
@@ -74,7 +74,7 @@ def _validateValue(self, value: Any) -> None:
if value is None:
return

if not (isinstance(value, str) or isinstance(value, Number)):
if not (isinstance(value, str | Number)):
raise TypeError(
f"Value {value} is of incorrect type {pexConfig.config._typeStr(value)}."
" Expected type str or a number"
16 changes: 7 additions & 9 deletions python/lsst/pipe/base/configOverrides.py
@@ -26,6 +26,7 @@
__all__ = ["ConfigOverrides"]

import ast
import contextlib
import inspect
from enum import Enum
from operator import attrgetter
@@ -109,7 +110,7 @@ def visit_Constant(self, node):

def visit_Dict(self, node):
"""Build dict out of component nodes if dict node encountered."""
return {self.visit(key): self.visit(value) for key, value in zip(node.keys, node.values)}
return {self.visit(key): self.visit(value) for key, value in zip(node.keys, node.values, strict=True)}

def visit_Set(self, node):
"""Build set out of node is set encountered."""
@@ -237,15 +238,12 @@ def addInstrumentOverride(self, instrument: Instrument, task_name: str) -> None:
self._overrides.append((OverrideTypes.Instrument, (instrument, task_name)))

def _parser(self, value, configParser):
try:
# Exception probably means it is a specific user string such as a URI.
# Let the value return as a string to attempt to continue to
# process as a string, another error will be raised in downstream
# code if that assumption is wrong
with contextlib.suppress(Exception):
value = configParser.visit(ast.parse(value, mode="eval").body)
except Exception:
# This probably means it is a specific user string such as a URI.
# Let the value return as a string to attempt to continue to
# process as a string, another error will be raised in downstream
# code if that assumption is wrong
pass

return value

def applyTo(self, config):
5 changes: 3 additions & 2 deletions python/lsst/pipe/base/connections.py
@@ -310,7 +310,7 @@ def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskCo
# by looping over the keys of the defaultTemplates dict specified at
# class declaration time.
templateValues = {
name: getattr(config.connections, name) for name in getattr(cls, "defaultTemplates").keys()
name: getattr(config.connections, name) for name in cls.defaultTemplates # type: ignore
}

# We now assemble a mapping of all connection instances keyed by
@@ -726,11 +726,12 @@ def buildDatasetRefs(
"""
inputDatasetRefs = InputQuantizedConnection()
outputDatasetRefs = OutputQuantizedConnection()
# operate on a reference object and an interable of names of class
# operate on a reference object and an iterable of names of class
# connection attributes
for refs, names in zip(
(inputDatasetRefs, outputDatasetRefs),
(itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
strict=True,
):
# get a name of a class connection attribute
for attributeName in names:
8 changes: 4 additions & 4 deletions python/lsst/pipe/base/executionButlerBuilder.py
@@ -159,25 +159,25 @@ def _accumulate(
for quantum in (n.quantum for n in graph):
for attrName in ("initInputs", "inputs", "outputs"):
attr: Mapping[DatasetType, DatasetRef | list[DatasetRef]] = getattr(quantum, attrName)
for type, refs in attr.items():
for refs in attr.values():
# This if block is because init inputs has a different
# signature for its items
if not isinstance(refs, (list, tuple)):
if not isinstance(refs, list | tuple):
refs = [refs]
for ref in refs:
if ref.isComponent():
ref = ref.makeCompositeRef()
check_refs.add(ref)
exist_map = butler._exists_many(check_refs, full_check=False)
existing_ids = set(ref.id for ref, exists in exist_map.items() if exists)
existing_ids = {ref.id for ref, exists in exist_map.items() if exists}
del exist_map

for quantum in (n.quantum for n in graph):
for attrName in ("initInputs", "inputs", "outputs"):
attr = getattr(quantum, attrName)

for type, refs in attr.items():
if not isinstance(refs, (list, tuple)):
if not isinstance(refs, list | tuple):
refs = [refs]
if type.component() is not None:
type = type.makeCompositeDatasetType()
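
Wrapping a generator expression in `set(...)` is replaced here by a set comprehension, which reads more directly and avoids the extra call. A standalone sketch of the same rewrite:

```python
exist_map = {"ref-a": True, "ref-b": False, "ref-c": True}

# Old spelling: generator expression passed to set()
existing_old = set(ref for ref, exists in exist_map.items() if exists)

# New spelling: set comprehension, same result
existing_new = {ref for ref, exists in exist_map.items() if exists}

assert existing_old == existing_new == {"ref-a", "ref-c"}
```
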
2 changes: 1 addition & 1 deletion python/lsst/pipe/base/graph/_loadHelpers.py
@@ -252,7 +252,7 @@ def _readBytes(self, start: int, stop: int) -> bytes:
return self._resourceHandle.read(stop - start)

def __enter__(self) -> "LoadHelper":
if isinstance(self.uri, (BinaryIO, BytesIO, BufferedRandom)):
if isinstance(self.uri, BinaryIO | BytesIO | BufferedRandom):
self._resourceHandle = self.uri
else:
self._resourceHandle = self._exitStack.enter_context(self.uri.open("rb"))
16 changes: 7 additions & 9 deletions python/lsst/pipe/base/graph/graph.py
@@ -570,7 +570,7 @@ def findTaskDefByName(self, taskName: str) -> list[TaskDef]:
multiple times with different labels.
"""
results = []
for task in self._taskToQuantumNode.keys():
for task in self._taskToQuantumNode:
split = task.taskName.split(".")
if split[-1] == taskName:
results.append(task)
@@ -590,7 +590,7 @@ def findTaskDefByLabel(self, label: str) -> TaskDef | None:
result : `TaskDef`
`TaskDef` objects that has the specified label.
"""
for task in self._taskToQuantumNode.keys():
for task in self._taskToQuantumNode:
if label == task.label:
return task
return None
@@ -636,10 +636,7 @@ def checkQuantumInGraph(self, quantum: Quantum) -> bool:
in_graph : `bool`
The result of searching for the quantum.
"""
for node in self:
if quantum == node.quantum:
return True
return False
return any(quantum == node.quantum for node in self)

def writeDotGraph(self, output: str | io.BufferedIOBase) -> None:
"""Write out the graph as a dot graph.
@@ -733,7 +730,7 @@ def determineInputsToQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
inputs : `set` of `QuantumNode`
All the nodes that are direct inputs to specified node.
"""
return set(pred for pred in self._connectedQuanta.predecessors(node))
return set(self._connectedQuanta.predecessors(node))

def determineOutputsOfQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
"""Return a set of `QuantumNode` that are direct outputs of a specified
@@ -749,7 +746,7 @@ def determineOutputsOfQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
outputs : `set` of `QuantumNode`
All the nodes that are direct outputs to specified node.
"""
return set(succ for succ in self._connectedQuanta.successors(node))
return set(self._connectedQuanta.successors(node))

def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
"""Return a graph of `QuantumNode` that are direct inputs and outputs
@@ -1106,7 +1103,8 @@ def _buildSaveObject(self, returnHeader: bool = False) -> bytearray | tuple[byte
count += len(dump)

headerData["DimensionRecords"] = {
key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
key: value.model_dump()
for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
}

# need to serialize this as a series of key,value tuples because of
20 changes: 9 additions & 11 deletions python/lsst/pipe/base/graphBuilder.py
@@ -29,6 +29,7 @@
# -------------------------------
# Imports of standard modules --
# -------------------------------
import contextlib
import itertools
import logging
from collections import ChainMap, defaultdict
@@ -231,7 +232,7 @@ def dimensions(self) -> DimensionGraph:
base = self.universe.empty
if len(self) == 0:
return base
return base.union(*[datasetType.dimensions for datasetType in self.keys()])
return base.union(*[datasetType.dimensions for datasetType in self])

def unpackSingleRefs(self, storage_classes: dict[str, str]) -> NamedKeyDict[DatasetType, DatasetRef]:
"""Unpack nested single-element `~lsst.daf.butler.DatasetRef` dicts
@@ -564,7 +565,7 @@ def computeSpatialExtent(self, pixelization: PixelizationABC) -> RangeSet:
result = RangeSet()
for dataset_type, datasets in itertools.chain(self.inputs.items(), self.outputs.items()):
if dataset_type.dimensions.spatial:
for data_id in datasets.keys():
for data_id in datasets:
result |= pixelization.envelope(data_id.region)
return result

@@ -604,7 +605,7 @@ def makeQuantum(self, datastore_records: Mapping[str, DatastoreRecordData] | Non
quantum_records = {}
input_refs = list(itertools.chain.from_iterable(helper.inputs.values()))
input_refs += list(initInputs.values())
input_ids = set(ref.id for ref in input_refs)
input_ids = {ref.id for ref in input_refs}
for datastore_name, records in datastore_records.items():
matching_records = records.subset(input_ids)
if matching_records is not None:
@@ -893,7 +894,7 @@ def __init__(self, pipeline: Pipeline | Iterable[TaskDef], *, registry: Registry
pipeline = pipeline.toExpandedPipeline()
self.tasks = [
_TaskScaffolding(taskDef=taskDef, parent=self, datasetTypes=taskDatasetTypes)
for taskDef, taskDatasetTypes in zip(pipeline, datasetTypes.byTask.values())
for taskDef, taskDatasetTypes in zip(pipeline, datasetTypes.byTask.values(), strict=True)
]

def __repr__(self) -> str:
@@ -1050,7 +1051,7 @@ def connectDataIds(
_LOG.debug("Not using dataset existence to constrain query.")
elif datasetQueryConstraint == DatasetQueryConstraintVariant.LIST:
constraint = set(datasetQueryConstraint)
inputs = {k.name: k for k in self.inputs.keys()}
inputs = {k.name: k for k in self.inputs}
if remainder := constraint.difference(inputs.keys()):
raise ValueError(
f"{remainder} dataset type(s) specified as a graph constraint, but"
@@ -1076,7 +1077,7 @@ def connectDataIds(
# Iterate over query results, populating data IDs for datasets and
# quanta and then connecting them to each other.
n = -1
for n, commonDataId in enumerate(commonDataIds):
for commonDataId in commonDataIds:
# Create DatasetRefs for all DatasetTypes from this result row,
# noting that we might have created some already.
# We remember both those that already existed and those that we
@@ -1205,11 +1206,8 @@ def resolveDatasetRefs(
# use it for resolving references but don't check it for existing refs.
run_exists = False
if run:
try:
with contextlib.suppress(MissingCollectionError):
run_exists = bool(registry.queryCollections(run))
except MissingCollectionError:
# Undocumented exception is raise if it does not exist
pass

skip_collections_wildcard: CollectionWildcard | None = None
skipExistingInRun = False
@@ -1647,7 +1645,7 @@ def _get_registry_dataset_types(self, registry: Registry) -> Iterable[DatasetTyp
chain.append(self.globalInitOutputs)

# Collect names of all dataset types.
all_names: set[str] = set(dstype.name for dstype in itertools.chain(*chain))
all_names: set[str] = {dstype.name for dstype in itertools.chain(*chain)}
dataset_types = {ds.name: ds for ds in registry.queryDatasetTypes(all_names)}

# Check for types that do not exist in registry yet: