From 981534af2cdf1ad60f66e0bcacf3d121755a0f28 Mon Sep 17 00:00:00 2001 From: Clare Saunders Date: Wed, 13 Mar 2024 14:52:52 -0700 Subject: [PATCH] Add optional streak masking in detectAndMeasure --- python/lsst/ip/diffim/detectAndMeasure.py | 65 ++++++++++++++++++++++- tests/test_detectAndMeasure.py | 31 +++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/python/lsst/ip/diffim/detectAndMeasure.py b/python/lsst/ip/diffim/detectAndMeasure.py index f350e83d..fddd737b 100644 --- a/python/lsst/ip/diffim/detectAndMeasure.py +++ b/python/lsst/ip/diffim/detectAndMeasure.py @@ -26,7 +26,7 @@ import lsst.daf.base as dafBase import lsst.geom from lsst.ip.diffim.utils import getPsfFwhm, angleMean -from lsst.meas.algorithms import SkyObjectsTask, SourceDetectionTask, SetPrimaryFlagsTask +from lsst.meas.algorithms import SkyObjectsTask, SourceDetectionTask, SetPrimaryFlagsTask, MaskStreaksTask from lsst.meas.base import ForcedMeasurementTask, ApplyApCorrTask, DetectorVisitIdGeneratorConfig import lsst.meas.deblender import lsst.meas.extensions.trailedSources # noqa: F401 @@ -89,12 +89,21 @@ class DetectAndMeasureConnections(pipeBase.PipelineTaskConnections, storageClass="ArrowAstropy", name="{fakesType}{coaddName}Diff_spatiallySampledMetrics", ) + maskedStreaks = pipeBase.connectionTypes.Output( + doc='Streak profile information.', + storageClass="ArrowNumpyDict", + dimensions=("instrument", "visit", "detector"), + name="{fakesType}{coaddName}Diff_streaks", + ) - def __init__(self, *, config=None): + def __init__(self, *, config): super().__init__(config=config) if not config.doWriteMetrics: self.outputs.remove("spatiallySampledMetrics") + if not (self.config.writeStreakInfo and self.config.doMaskStreaks): + self.outputs.remove("maskedStreaks") + class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig, pipelineConnections=DetectAndMeasureConnections): @@ -159,6 +168,22 @@ class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig, target=SkyObjectsTask, doc="Generate sky sources", ) + doMaskStreaks = pexConfig.Field( + dtype=bool, + default=False, + doc="Turn on streak masking", + ) + maskStreaks = pexConfig.ConfigurableField( + target=MaskStreaksTask, + doc="Subtask for masking streaks. Only used if doMaskStreaks is True. " + "Adds a mask plane to an exposure, with the mask plane name set by streakMaskName.", + ) + writeStreakInfo = pexConfig.Field( + dtype=bool, + default=False, + doc="Record the parameters of any detected streaks. For LSST, this should be turned off except for " + "development work." + ) setPrimaryFlags = pexConfig.ConfigurableField( target=SetPrimaryFlagsTask, doc="Task to add isPrimary and deblending-related flags to the catalog." @@ -276,6 +301,8 @@ def __init__(self, **kwargs): self.schema.addField("srcMatchId", "L", "unique id of source match") if self.config.doSkySources: self.makeSubtask("skySources", schema=self.schema) + if self.config.doMaskStreaks: + self.makeSubtask("maskStreaks") # Check that the schema and config are consistent for flag in self.config.badSourceFlags: @@ -464,6 +491,9 @@ def processResults(self, science, matchedTemplate, difference, sources, idFactor self.metadata.add("nMergedDiaSources", len(initialDiaSources)) + if self.config.doMaskStreaks: + streakInfo = self._runStreakMasking(difference.maskedImage) + if self.config.doSkySources: self.addSkySources(initialDiaSources, difference.mask, difference.info.id) @@ -484,6 +514,8 @@ def processResults(self, science, matchedTemplate, difference, sources, idFactor diaSources=diaSources, spatiallySampledMetrics=spatiallySampledMetrics, ) + if self.config.doMaskStreaks and self.config.writeStreakInfo: + measurementResults.mergeItems(streakInfo, 'maskedStreaks') return measurementResults @@ -756,6 +788,35 @@ def _evaluateLocalMetric(self, src, diaSources, science, matchedTemplate, differ evaluateMaskFraction(difference.mask[bbox], maskPlane) ) + def _runStreakMasking(self, maskedImage): + """Do streak masking at put results into catalog. + + Parameters + ---------- + maskedImage: `lsst.afw.image.maskedImage` + The image in which to search for streaks. Must have a detection + mask. + + Returns + ------- + streakInfo: `lsst.pipe.base.Struct` + ``rho`` : `np.ndarray` + Angle of detected streak. + ``theta`` : `np.ndarray` + Distance from center of detected streak. + ``sigma`` : `np.ndarray` + Width of streak profile. + """ + streaks = self.maskStreaks.run(maskedImage) + if self.config.writeStreakInfo: + rhos = np.array([line.rho for line in streaks.lines]) + thetas = np.array([line.theta for line in streaks.lines]) + sigmas = np.array([line.sigma for line in streaks.lines]) + streakInfo = {'rho': rhos, 'theta': thetas, 'sigma': sigmas} + else: + streakInfo = {'rho': np.array([]), 'theta': np.array([]), 'sigma': np.array([])} + return pipeBase.Struct(maskedStreaks=streakInfo) + class DetectAndMeasureScoreConnections(DetectAndMeasureConnections): scoreExposure = pipeBase.connectionTypes.Input( diff --git a/tests/test_detectAndMeasure.py b/tests/test_detectAndMeasure.py index 6e7f3ea9..8a367c36 100644 --- a/tests/test_detectAndMeasure.py +++ b/tests/test_detectAndMeasure.py @@ -532,6 +532,37 @@ def test_fake_mask_plane_propagation(self): self.assertFalse(diaSrc['base_PixelFlags_flag_injected']) self.assertFalse(diaSrc['base_PixelFlags_flag_injectedCenter']) + def test_mask_streaks(self): + """Run detection on a difference image containing a streak. + """ + # Set up the simulated images + noiseLevel = 1. + staticSeed = 1 + fluxLevel = 500 + kwargs = {"seed": staticSeed, "psfSize": 2.4, "fluxLevel": fluxLevel} + science, sources = makeTestImage(noiseLevel=noiseLevel, noiseSeed=6, **kwargs) + matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs) + + # Configure the detection Task + detectionTask = self._setup_detection(doMerge=False, doMaskStreaks=True) + + # Test that no streaks are detected + difference = science.clone() + difference.maskedImage -= matchedTemplate.maskedImage + output = detectionTask.run(science, matchedTemplate, difference) + outMask = output.subtractedMeasuredExposure.mask.array + streakMask = output.subtractedMeasuredExposure.mask.getPlaneBitMask("STREAK") + streakMaskSet = (outMask & streakMask) > 0 + self.assertTrue(np.all(streakMaskSet == 0)) + + # Add streak-like shape and check that streak is detected + difference.image.array[20:23, 40:200] += 50 + output = detectionTask.run(science, matchedTemplate, difference) + outMask = output.subtractedMeasuredExposure.mask.array + streakMask = output.subtractedMeasuredExposure.mask.getPlaneBitMask("STREAK") + streakMaskSet = (outMask & streakMask) > 0 + self.assertTrue(np.all(streakMaskSet[20:23, 40:200])) + class DetectAndMeasureScoreTest(DetectAndMeasureTestBase, lsst.utils.tests.TestCase): detectionTask = detectAndMeasure.DetectAndMeasureScoreTask