Skip to content

Commit

Permalink
Add optional streak masking in detectAndMeasure
Browse files Browse the repository at this point in the history
  • Loading branch information
cmsaunders committed Mar 14, 2024
1 parent 24a354c commit d57d58e
Showing 1 changed file with 58 additions and 1 deletion.
59 changes: 58 additions & 1 deletion python/lsst/ip/diffim/detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import lsst.afw.table as afwTable
import lsst.daf.base as dafBase
from lsst.meas.algorithms import SkyObjectsTask, SourceDetectionTask
from lsst.meas.algorithms import SkyObjectsTask, SourceDetectionTask, MaskStreaksTask
from lsst.meas.base import ForcedMeasurementTask, ApplyApCorrTask, DetectorVisitIdGeneratorConfig
import lsst.meas.extensions.trailedSources # noqa: F401
import lsst.meas.extensions.shapeHSM
Expand Down Expand Up @@ -78,6 +78,17 @@ class DetectAndMeasureConnections(pipeBase.PipelineTaskConnections,
storageClass="ExposureF",
name="{fakesType}{coaddName}Diff_differenceExp",
)
maskedStreaks = pipeBase.connectionTypes.Output(
doc='Streak profile information.',
storageClass="ArrowNumpyDict",
dimensions=("instrument", "visit", "detector"),
name="{coaddName}Diff_streaks",
)

def __init__(self, *, config):
super().__init__(config=config)
if not (self.config.outputStreakInfo and self.config.doMaskStreaks):
self.outputs.remove("maskedStreaks")


class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
Expand Down Expand Up @@ -139,6 +150,22 @@ class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
target=SkyObjectsTask,
doc="Generate sky sources",
)
doMaskStreaks = pexConfig.Field(
dtype=bool,
default=False,
doc="Turn on streak masking",
)
maskStreaks = pexConfig.ConfigurableField(
target=MaskStreaksTask,
doc="Subtask for masking streaks. Only used if doMaskStreaks is True. "
"Adds a mask plane to an exposure, with the mask plane name set by streakMaskName.",
)
outputStreakInfo = pexConfig.Field(
dtype=bool,
default=False,
doc="Record the parameters of any detected streaks. For LSST, this should be turned off except for "
"development work."
)
badSourceFlags = lsst.pex.config.ListField(
dtype=str,
doc="Sources with any of these flags set are removed before writing the output catalog.",
Expand Down Expand Up @@ -231,6 +258,8 @@ def __init__(self, **kwargs):
if self.config.doSkySources:
self.makeSubtask("skySources")
self.skySourceKey = self.schema.addField("sky_source", type="Flag", doc="Sky objects.")
if self.config.doMaskStreaks:
self.makeSubtask("maskStreaks")

# Check that the schema and config are consistent
for flag in self.config.badSourceFlags:
Expand Down Expand Up @@ -341,6 +370,9 @@ def processResults(self, science, matchedTemplate, difference, sources, table,
initialDiaSources = sources
self.metadata.add("nMergedDiaSources", len(initialDiaSources))

if self.config.doMaskStreaks:
streakInfo = self._runStreakMasking(difference.maskedImage)

if self.config.doSkySources:
self.addSkySources(initialDiaSources, difference.mask, difference.info.id)

Expand All @@ -354,6 +386,9 @@ def processResults(self, science, matchedTemplate, difference, sources, table,
subtractedMeasuredExposure=difference,
diaSources=diaSources,
)
if self.config.doMaskStreaks and self.config.outputStreakInfo:
measurementResults.mergeItems(streakInfo, 'maskedStreaks')

self.calculateMetrics(difference)

return measurementResults
Expand Down Expand Up @@ -490,6 +525,28 @@ def calculateMetrics(self, difference):
self.metadata.add("nBadPixelsDetectedPositive", np.sum(detPosPix))
self.metadata.add("nBadPixelsDetectedNegative", np.sum(detNegPix))

def _runStreakMasking(self, maskedImage):
"""Do streak masking at put results into catalog.
Parameters
----------
maskedImage: `lsst.afw.image.maskedImage`
The image in which to search for streaks. Must have a detection
mask.
Returns
-------
streakInfo: `dict` ['str', 'float']
Catalog of streak profile information.
"""
streaks = self.maskStreaks.run(maskedImage)
rhos = np.array([line.rho for line in streaks.lines])
thetas = np.array([line.theta for line in streaks.lines])
sigmas = np.array([line.sigma for line in streaks.lines])

streakInfo = {'rho': rhos, 'theta': thetas, 'sigma': sigmas}
return pipeBase.Struct(maskedStreaks=streakInfo)


class DetectAndMeasureScoreConnections(DetectAndMeasureConnections):
scoreExposure = pipeBase.connectionTypes.Input(
Expand Down

0 comments on commit d57d58e

Please sign in to comment.