Skip to content

Commit

Permalink
Merge pull request #303 from lsst/tickets/DM-43083
Browse files Browse the repository at this point in the history
DM-43083: Add optional streak masking in detectAndMeasure
  • Loading branch information
cmsaunders authored Apr 16, 2024
2 parents 5d7b63b + e09ba32 commit 6074fc1
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 2 deletions.
65 changes: 63 additions & 2 deletions python/lsst/ip/diffim/detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import lsst.daf.base as dafBase
import lsst.geom
from lsst.ip.diffim.utils import getPsfFwhm, angleMean
from lsst.meas.algorithms import SkyObjectsTask, SourceDetectionTask, SetPrimaryFlagsTask
from lsst.meas.algorithms import SkyObjectsTask, SourceDetectionTask, SetPrimaryFlagsTask, MaskStreaksTask
from lsst.meas.base import ForcedMeasurementTask, ApplyApCorrTask, DetectorVisitIdGeneratorConfig
import lsst.meas.deblender
import lsst.meas.extensions.trailedSources # noqa: F401
Expand Down Expand Up @@ -89,12 +89,21 @@ class DetectAndMeasureConnections(pipeBase.PipelineTaskConnections,
storageClass="ArrowAstropy",
name="{fakesType}{coaddName}Diff_spatiallySampledMetrics",
)
maskedStreaks = pipeBase.connectionTypes.Output(
doc='Streak profile information.',
storageClass="ArrowNumpyDict",
dimensions=("instrument", "visit", "detector"),
name="{fakesType}{coaddName}Diff_streaks",
)

def __init__(self, *, config=None):
def __init__(self, *, config):
super().__init__(config=config)
if not config.doWriteMetrics:
self.outputs.remove("spatiallySampledMetrics")

if not (self.config.writeStreakInfo and self.config.doMaskStreaks):
self.outputs.remove("maskedStreaks")


class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
pipelineConnections=DetectAndMeasureConnections):
Expand Down Expand Up @@ -159,6 +168,22 @@ class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
target=SkyObjectsTask,
doc="Generate sky sources",
)
doMaskStreaks = pexConfig.Field(
dtype=bool,
default=False,
doc="Turn on streak masking",
)
maskStreaks = pexConfig.ConfigurableField(
target=MaskStreaksTask,
doc="Subtask for masking streaks. Only used if doMaskStreaks is True. "
"Adds a mask plane to an exposure, with the mask plane name set by streakMaskName.",
)
writeStreakInfo = pexConfig.Field(
dtype=bool,
default=False,
doc="Record the parameters of any detected streaks. For LSST, this should be turned off except for "
"development work."
)
setPrimaryFlags = pexConfig.ConfigurableField(
target=SetPrimaryFlagsTask,
doc="Task to add isPrimary and deblending-related flags to the catalog."
Expand Down Expand Up @@ -275,6 +300,8 @@ def __init__(self, **kwargs):
self.schema.addField("srcMatchId", "L", "unique id of source match")
if self.config.doSkySources:
self.makeSubtask("skySources", schema=self.schema)
if self.config.doMaskStreaks:
self.makeSubtask("maskStreaks")

# Check that the schema and config are consistent
for flag in self.config.badSourceFlags:
Expand Down Expand Up @@ -463,6 +490,9 @@ def processResults(self, science, matchedTemplate, difference, sources, idFactor

self.metadata.add("nMergedDiaSources", len(initialDiaSources))

if self.config.doMaskStreaks:
streakInfo = self._runStreakMasking(difference.maskedImage)

if self.config.doSkySources:
self.addSkySources(initialDiaSources, difference.mask, difference.info.id)

Expand All @@ -483,6 +513,8 @@ def processResults(self, science, matchedTemplate, difference, sources, idFactor
diaSources=diaSources,
spatiallySampledMetrics=spatiallySampledMetrics,
)
if self.config.doMaskStreaks and self.config.writeStreakInfo:
measurementResults.mergeItems(streakInfo, 'maskedStreaks')

return measurementResults

Expand Down Expand Up @@ -755,6 +787,35 @@ def _evaluateLocalMetric(self, src, diaSources, science, matchedTemplate, differ
evaluateMaskFraction(difference.mask[bbox], maskPlane)
)

def _runStreakMasking(self, maskedImage):
"""Do streak masking at put results into catalog.
Parameters
----------
maskedImage: `lsst.afw.image.maskedImage`
The image in which to search for streaks. Must have a detection
mask.
Returns
-------
streakInfo: `lsst.pipe.base.Struct`
``rho`` : `np.ndarray`
Angle of detected streak.
``theta`` : `np.ndarray`
Distance from center of detected streak.
``sigma`` : `np.ndarray`
Width of streak profile.
"""
streaks = self.maskStreaks.run(maskedImage)
if self.config.writeStreakInfo:
rhos = np.array([line.rho for line in streaks.lines])
thetas = np.array([line.theta for line in streaks.lines])
sigmas = np.array([line.sigma for line in streaks.lines])
streakInfo = {'rho': rhos, 'theta': thetas, 'sigma': sigmas}
else:
streakInfo = {'rho': np.array([]), 'theta': np.array([]), 'sigma': np.array([])}
return pipeBase.Struct(maskedStreaks=streakInfo)


class DetectAndMeasureScoreConnections(DetectAndMeasureConnections):
scoreExposure = pipeBase.connectionTypes.Input(
Expand Down
31 changes: 31 additions & 0 deletions tests/test_detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,37 @@ def test_fake_mask_plane_propagation(self):
self.assertFalse(diaSrc['base_PixelFlags_flag_injected'])
self.assertFalse(diaSrc['base_PixelFlags_flag_injectedCenter'])

def test_mask_streaks(self):
"""Run detection on a difference image containing a streak.
"""
# Set up the simulated images
noiseLevel = 1.
staticSeed = 1
fluxLevel = 500
kwargs = {"seed": staticSeed, "psfSize": 2.4, "fluxLevel": fluxLevel}
science, sources = makeTestImage(noiseLevel=noiseLevel, noiseSeed=6, **kwargs)
matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs)

# Configure the detection Task
detectionTask = self._setup_detection(doMerge=False, doMaskStreaks=True)

# Test that no streaks are detected
difference = science.clone()
difference.maskedImage -= matchedTemplate.maskedImage
output = detectionTask.run(science, matchedTemplate, difference)
outMask = output.subtractedMeasuredExposure.mask.array
streakMask = output.subtractedMeasuredExposure.mask.getPlaneBitMask("STREAK")
streakMaskSet = (outMask & streakMask) > 0
self.assertTrue(np.all(streakMaskSet == 0))

# Add streak-like shape and check that streak is detected
difference.image.array[20:23, 40:200] += 50
output = detectionTask.run(science, matchedTemplate, difference)
outMask = output.subtractedMeasuredExposure.mask.array
streakMask = output.subtractedMeasuredExposure.mask.getPlaneBitMask("STREAK")
streakMaskSet = (outMask & streakMask) > 0
self.assertTrue(np.all(streakMaskSet[20:23, 40:200]))


class DetectAndMeasureScoreTest(DetectAndMeasureTestBase, lsst.utils.tests.TestCase):
detectionTask = detectAndMeasure.DetectAndMeasureScoreTask
Expand Down

0 comments on commit 6074fc1

Please sign in to comment.