From d18972cf846a870feee88698cbfcfc0c9200734e Mon Sep 17 00:00:00 2001 From: fred3m Date: Mon, 23 Dec 2024 20:49:44 -0800 Subject: [PATCH] Update for scarlet lite changes --- .../pipe/tasks/deblendCoaddSourcesPipeline.py | 27 ++++++++++++++++--- tests/test_isPrimaryFlag.py | 10 ++++++- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/python/lsst/pipe/tasks/deblendCoaddSourcesPipeline.py b/python/lsst/pipe/tasks/deblendCoaddSourcesPipeline.py index 393bcd98f..577d25ae8 100644 --- a/python/lsst/pipe/tasks/deblendCoaddSourcesPipeline.py +++ b/python/lsst/pipe/tasks/deblendCoaddSourcesPipeline.py @@ -120,6 +120,13 @@ class DeblendCoaddSourcesMultiConnections(PipelineTaskConnections, multiple=True, dimensions=("tract", "patch", "band", "skymap") ) + deconvolvedCoadds = cT.Input( + doc="Deconvolved coadds", + name="deconvolved_{inputCoaddName}_coadd", + storageClass="ExposureF", + multiple=True, + dimensions=("tract", "patch", "band", "skymap") + ) outputSchema = cT.InitOutput( doc="Output of the schema used in deblending task", name="{outputCoaddName}Coadd_deblendedFlux_schema", @@ -251,14 +258,26 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): inputRefs = reorderRefs(inputRefs, bandOrder, dataIdKey="band") inputs = butlerQC.get(inputRefs) inputs["idFactory"] = self.config.idGenerator.apply(butlerQC.quantum.dataId).make_table_id_factory() - inputs["filters"] = [dRef.dataId["band"] for dRef in inputRefs.coadds] + + # Ensure that the coadd bands and deconvolved coadd bands match + bands = [dRef.dataId["band"] for dRef in inputRefs.coadds] + deconvBands = [dRef.dataId["band"] for dRef in inputRefs.deconvolvedCoadds] + if len(bands) != len(deconvBands): + raise RuntimeError("Number of coadd bands and deconvolved coadd bands do not match") + + for band, deconvBand in zip(bands, deconvBands): + if band != deconvBand: + raise RuntimeError(f"Bands {band} and {deconvBand} do not match") + + inputs["bands"] = [dRef.dataId["band"] for dRef in inputRefs.coadds] outputs = self.run(**inputs) butlerQC.put(outputs, outputRefs) - def run(self, coadds, filters, mergedDetections, idFactory): + def run(self, coadds, bands, mergedDetections, deconvolvedCoadds, idFactory): sources = self._makeSourceCatalog(mergedDetections, idFactory) - multiExposure = afwImage.MultibandExposure.fromExposures(filters, coadds) - catalog, modelData = self.multibandDeblend.run(multiExposure, sources) + multiExposure = afwImage.MultibandExposure.fromExposures(bands, coadds) + mDeconvolved = afwImage.MultibandExposure.fromExposures(bands, deconvolvedCoadds) + catalog, modelData = self.multibandDeblend.run(multiExposure, mDeconvolved, sources) retStruct = Struct(deblendedCatalog=catalog, scarletModelData=modelData) return retStruct diff --git a/tests/test_isPrimaryFlag.py b/tests/test_isPrimaryFlag.py index 7b247d1a6..e7ebb535c 100755 --- a/tests/test_isPrimaryFlag.py +++ b/tests/test_isPrimaryFlag.py @@ -33,6 +33,7 @@ from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask, SetPrimaryFlagsTask import lsst.meas.extensions.scarlet as mes from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask +from lsst.meas.extensions.scarlet.deconvolveExposureTask import DeconvolveExposureTask from lsst.meas.base import SingleFrameMeasurementTask from lsst.afw.table import SourceCatalog @@ -214,6 +215,10 @@ def testIsScarletPrimaryFlag(self): skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig) schema.addField("merge_peak_sky", type="Flag") + # Initialize the deconvolution task + deconvolveConfig = DeconvolveExposureTask.ConfigClass() + deconvolveTask = DeconvolveExposureTask(config=deconvolveConfig) + # Initialize the deblender task scarletConfig = ScarletDeblendTask.ConfigClass() scarletConfig.maxIter = 20 @@ -243,8 +248,11 @@ def testIsScarletPrimaryFlag(self): src = catalog.addNew() src.setFootprint(foot) src.set("merge_peak_sky", True) + # deconvolve the images + deconvolved = deconvolveTask.run(coadds["test"], catalog).deconvolved + mDeconvolved = afwImage.MultibandExposure.fromExposures(["test"], [deconvolved]) # deblend - catalog, modelData = deblendTask.run(coadds, catalog) + catalog, modelData = deblendTask.run(coadds, mDeconvolved, catalog) # Attach footprints to the catalog mes.io.updateCatalogFootprints( modelData=modelData,