Skip to content

Commit

Permalink
Merge pull request #40 from lsst/tickets/DM-41490
Browse files Browse the repository at this point in the history
DM-41490: Recover WCS for input with astrometry failures
  • Loading branch information
cmsaunders authored Jan 19, 2024
2 parents 115f0b7 + ac40fb2 commit 21efbaf
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 20 deletions.
109 changes: 89 additions & 20 deletions python/lsst/drp/tasks/gbdesAstrometricFit.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def run(
self.log.info("Fit the WCSs")
# Set up a YAML-type string using the config variables and a sample
# visit
inputYAML = self.make_yaml(inputVisitSummaries[0])
inputYAML, mapTemplate = self.make_yaml(inputVisitSummaries[0])

# Set the verbosity level for WCSFit from the task log level.
# TODO: DM-36850, Add lsst.log to gbdes so that log messages are
Expand Down Expand Up @@ -590,7 +590,7 @@ def run(
)
self.log.info("WCS fitting done")

outputWCSs = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo)
outputWCSs = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo, mapTemplate=mapTemplate)
outputCatalog = wcsf.getOutputCatalog()
starCatalog = wcsf.getStarCatalog()
modelParams = self._compute_model_params(wcsf) if self.config.saveModelParams else None
Expand Down Expand Up @@ -630,6 +630,7 @@ def _prep_sky(self, inputVisitSummaries, epoch, fieldName="Field"):
detectorCorners = [
lsst.geom.SpherePoint(ra, dec, lsst.geom.degrees).getVector()
for (ra, dec) in zip(visSum["raCorners"].ravel(), visSum["decCorners"].ravel())
if (np.isfinite(ra) and (np.isfinite(dec)))
]
allDetectorCorners.extend(detectorCorners)
boundingCircle = lsst.sphgeom.ConvexPolygon.convexHull(allDetectorCorners).getBoundingCircle()
Expand Down Expand Up @@ -738,6 +739,24 @@ def _get_exposure_info(

for row in visitSummary:
detector = row["id"]

wcs = row.getWcs()
if wcs is None:
self.log.warning(
"WCS is None for visit %d, detector %d: this extension (visit/detector) will be "
"dropped.",
visit,
detector,
)
continue
else:
wcsRA = wcs.getSkyOrigin().getRa().asRadians()
wcsDec = wcs.getSkyOrigin().getDec().asRadians()
tangentPoint = wcsfit.Gnomonic(wcsRA, wcsDec)
mapping = wcs.getFrameDict().getMapping("PIXELS", "IWC")
gbdes_wcs = wcsfit.Wcs(wcsfit.ASTMap(mapping), tangentPoint)
wcss.append(gbdes_wcs)

if detector not in detectors:
detectors.append(detector)
detectorBounds = wcsfit.Bounds(
Expand All @@ -752,14 +771,6 @@ def _get_exposure_info(
extensionDetectors.append(detector)
extensionType.append("SCIENCE")

wcs = row.getWcs()
wcsRA = wcs.getSkyOrigin().getRa().asRadians()
wcsDec = wcs.getSkyOrigin().getDec().asRadians()
tangentPoint = wcsfit.Gnomonic(wcsRA, wcsDec)
mapping = wcs.getFrameDict().getMapping("PIXELS", "IWC")
gbdes_wcs = wcsfit.Wcs(wcsfit.ASTMap(mapping), tangentPoint)
wcss.append(gbdes_wcs)

fieldNumbers = list(np.ones(len(exposureNames), dtype=int) * fieldNumber)
instrumentNumbers = list(np.ones(len(exposureNames), dtype=int) * instrumentNumber)

Expand Down Expand Up @@ -926,6 +937,34 @@ def _load_refcat(

return refObjects, refCovariance

@staticmethod
def _find_extension_index(extensionInfo, visit, detector):
"""Find the index for a given extension from its visit and detector
number.
If no match is found, None is returned.
Parameters
----------
extensionInfo : `lsst.pipe.base.Struct`
Struct containing properties for each extension.
visit : `int`
Visit number
detector : `int`
Detector number
Returns
-------
extensionIndex : `int` or None
Index of this extension
"""
findExtension = np.flatnonzero((extensionInfo.visit == visit) & (extensionInfo.detector == detector))
if len(findExtension) == 0:
extensionIndex = None
else:
extensionIndex = findExtension[0]
return extensionIndex

def _load_catalogs_and_associate(
self, associations, inputCatalogRefs, extensionInfo, fieldIndex=0, instrumentIndex=0
):
Expand Down Expand Up @@ -999,9 +1038,12 @@ class `wcsfit.FoFClass`, associating them into matches as you go.
goodInds = selected.selected & goodShapes

isStar = np.ones(goodInds.sum())
extensionIndex = np.flatnonzero(
(extensionInfo.visit == visit) & (extensionInfo.detector == detector)
)[0]
extensionIndex = self._find_extension_index(extensionInfo, visit, detector)
if extensionIndex is None:
# This extension does not have information necessary for
# fit. Skip the detections from this detector for this
# visit.
continue
detectorIndex = extensionInfo.detectorIndex[extensionIndex]
visitIndex = extensionInfo.visitIndex[extensionIndex]

Expand Down Expand Up @@ -1087,6 +1129,8 @@ def make_yaml(self, inputVisitSummary, inputFile=None):
-------
inputYAML : `wcsfit.YAMLCollector`
YAML object containing the model description.
inputDict : `dict` [`str`, `str`]
Dictionary containing the model description.
"""
if inputFile is not None:
inputYAML = wcsfit.YAMLCollector(inputFile, "PixelMapCollection")
Expand Down Expand Up @@ -1137,7 +1181,7 @@ def make_yaml(self, inputVisitSummary, inputFile=None):
inputYAML.addInput(yaml.dump(inputDict))
inputYAML.addInput("Identity:\n Type: Identity\n")

return inputYAML
return inputYAML, inputDict

def _add_objects(self, wcsf, inputCatalogRefs, sourceIndices, extensionInfo, columns):
"""Add science sources to the wcsfit.WCSFit object.
Expand All @@ -1164,9 +1208,13 @@ def _add_objects(self, wcsf, inputCatalogRefs, sourceIndices, extensionInfo, col
for detector in detectors:
detectorSources = inputCatalog[inputCatalog["detector"] == detector]

extensionIndex = np.flatnonzero(
(extensionInfo.visit == visit) & (extensionInfo.detector == detector)
)[0]
extensionIndex = self._find_extension_index(extensionInfo, visit, detector)
if extensionIndex is None:
# This extension does not have information necessary for
# fit. Skip the detections from this detector for this
# visit.
continue

sourceCat = detectorSources[sourceIndices[extensionIndex]]

xCov = sourceCat["xErr"] ** 2
Expand Down Expand Up @@ -1288,7 +1336,7 @@ def _make_afw_wcs(self, mapDict, centerRA, centerDec, doNormalizePixels=False, x
outWCS = afwgeom.SkyWcs(frameDict)
return outWCS

def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo):
def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate=None):
"""Make a WCS object out of the WCS models.
Parameters
Expand Down Expand Up @@ -1335,8 +1383,29 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo):

for d, detector in enumerate(visitSummary["id"]):
mapName = f"{visit}/{detector}"

mapElements = wcsf.mapCollection.orderAtoms(f"{mapName}/base")
if mapName in wcsf.mapCollection.allMapNames():
mapElements = wcsf.mapCollection.orderAtoms(f"{mapName}/base")
else:
# This extension was not fit, but the WCS can be recovered
# using the maps fit from sources on other visits but the
# same detector and from sources on other detectors from
# this visit.
genericElements = mapTemplate["EXPOSURE/DEVICE/base"]["Elements"]
mapElements = []
instrument = visitSummary[0].getVisitInfo().instrumentLabel
# Go through the generic map components to build the names
# of the specific maps for this extension.
for component in genericElements:
elements = mapTemplate[component]["Elements"]
for element in elements:
# TODO: DM-42519, gbdes sets the "BAND" to the
# instrument name currently. This will need to be
# disambiguated if we run on multiple bands at
# once.
element = element.replace("BAND", str(instrument))
element = element.replace("EXPOSURE", str(visit))
element = element.replace("DEVICE", str(detector))
mapElements.append(element)
mapDict = {}
for m, mapElement in enumerate(mapElements):
mapType = wcsf.mapCollection.getMapType(mapElement)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_gbdesAstrometricFit.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,31 @@ def test_run(self):
np.testing.assert_array_less(absDiffX, 1e-7)
np.testing.assert_array_less(absDiffY, 1e-7)

def test_missingWcs(self):
"""Test that task does not fail when the input WCS is None for one
extension and that the fit WCS for that extension returns a finite
result.
"""
inputVisitSummary = self.inputVisitSummary.copy()
# Set one WCS to be None
testVisit = 0
testDetector = 20
inputVisitSummary[testVisit][testDetector].setWcs(None)

outputs = self.task.run(
self.inputCatalogRefs,
inputVisitSummary,
instrumentName=self.instrumentName,
refEpoch=self.refEpoch,
refObjectLoader=self.refObjectLoader,
)

# Check that the fit WCS for the extension with input WCS=None returns
# finite sky values.
testWcs = outputs.outputWCSs[self.testVisits[testVisit]][testDetector].getWcs()
testSky = testWcs.pixelToSky(0, 0)
self.assertTrue(testSky.isFinite())


def setup_module(module):
lsst.utils.tests.init()
Expand Down

0 comments on commit 21efbaf

Please sign in to comment.