Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-47738: Update for scarlet lite changes #1018

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions python/lsst/pipe/tasks/deblendCoaddSourcesPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ class DeblendCoaddSourcesMultiConnections(PipelineTaskConnections,
multiple=True,
dimensions=("tract", "patch", "band", "skymap")
)
deconvolvedCoadds = cT.Input(
doc="Deconvolved coadds",
name="deconvolved_{inputCoaddName}_coadd",
storageClass="ExposureF",
multiple=True,
dimensions=("tract", "patch", "band", "skymap")
)
outputSchema = cT.InitOutput(
doc="Output of the schema used in deblending task",
name="{outputCoaddName}Coadd_deblendedFlux_schema",
Expand Down Expand Up @@ -251,14 +258,26 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputRefs = reorderRefs(inputRefs, bandOrder, dataIdKey="band")
inputs = butlerQC.get(inputRefs)
inputs["idFactory"] = self.config.idGenerator.apply(butlerQC.quantum.dataId).make_table_id_factory()
inputs["filters"] = [dRef.dataId["band"] for dRef in inputRefs.coadds]

# Ensure that the coadd bands and deconvolved coadd bands match
bands = [dRef.dataId["band"] for dRef in inputRefs.coadds]
deconvBands = [dRef.dataId["band"] for dRef in inputRefs.deconvolvedCoadds]
if len(bands) != len(deconvBands):
raise RuntimeError("Number of coadd bands and deconvolved coadd bands do not match")

for band, deconvBand in zip(bands, deconvBands):
if band != deconvBand:
raise RuntimeError(f"Bands {band} and {deconvBand} do not match")

inputs["bands"] = [dRef.dataId["band"] for dRef in inputRefs.coadds]
outputs = self.run(**inputs)
butlerQC.put(outputs, outputRefs)

def run(self, coadds, filters, mergedDetections, idFactory):
def run(self, coadds, bands, mergedDetections, deconvolvedCoadds, idFactory):
sources = self._makeSourceCatalog(mergedDetections, idFactory)
multiExposure = afwImage.MultibandExposure.fromExposures(filters, coadds)
catalog, modelData = self.multibandDeblend.run(multiExposure, sources)
multiExposure = afwImage.MultibandExposure.fromExposures(bands, coadds)
mDeconvolved = afwImage.MultibandExposure.fromExposures(bands, deconvolvedCoadds)
catalog, modelData = self.multibandDeblend.run(multiExposure, mDeconvolved, sources)
retStruct = Struct(deblendedCatalog=catalog, scarletModelData=modelData)
return retStruct

Expand Down
10 changes: 9 additions & 1 deletion tests/test_isPrimaryFlag.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask, SetPrimaryFlagsTask
import lsst.meas.extensions.scarlet as mes
from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
from lsst.meas.extensions.scarlet.deconvolveExposureTask import DeconvolveExposureTask
from lsst.meas.base import SingleFrameMeasurementTask
from lsst.afw.table import SourceCatalog

Expand Down Expand Up @@ -214,6 +215,10 @@ def testIsScarletPrimaryFlag(self):
skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
schema.addField("merge_peak_sky", type="Flag")

# Initialize the deconvolution task
deconvolveConfig = DeconvolveExposureTask.ConfigClass()
deconvolveTask = DeconvolveExposureTask(config=deconvolveConfig)

# Initialize the deblender task
scarletConfig = ScarletDeblendTask.ConfigClass()
scarletConfig.maxIter = 20
Expand Down Expand Up @@ -243,8 +248,11 @@ def testIsScarletPrimaryFlag(self):
src = catalog.addNew()
src.setFootprint(foot)
src.set("merge_peak_sky", True)
# deconvolve the images
deconvolved = deconvolveTask.run(coadds["test"], catalog).deconvolved
mDeconvolved = afwImage.MultibandExposure.fromExposures(["test"], [deconvolved])
# deblend
catalog, modelData = deblendTask.run(coadds, catalog)
catalog, modelData = deblendTask.run(coadds, mDeconvolved, catalog)
# Attach footprints to the catalog
mes.io.updateCatalogFootprints(
modelData=modelData,
Expand Down
Loading