From ac40fb248ca96b5db4c4f8190bb64055fa9edf41 Mon Sep 17 00:00:00 2001 From: Clare Saunders Date: Fri, 15 Dec 2023 11:26:45 -0800 Subject: [PATCH] Recover WCS for input with astrometry failures --- python/lsst/drp/tasks/gbdesAstrometricFit.py | 109 +++++++++++++++---- tests/test_gbdesAstrometricFit.py | 25 +++++ 2 files changed, 114 insertions(+), 20 deletions(-) diff --git a/python/lsst/drp/tasks/gbdesAstrometricFit.py b/python/lsst/drp/tasks/gbdesAstrometricFit.py index d7966c02..d2bc2327 100644 --- a/python/lsst/drp/tasks/gbdesAstrometricFit.py +++ b/python/lsst/drp/tasks/gbdesAstrometricFit.py @@ -542,7 +542,7 @@ def run( self.log.info("Fit the WCSs") # Set up a YAML-type string using the config variables and a sample # visit - inputYAML = self.make_yaml(inputVisitSummaries[0]) + inputYAML, mapTemplate = self.make_yaml(inputVisitSummaries[0]) # Set the verbosity level for WCSFit from the task log level. # TODO: DM-36850, Add lsst.log to gbdes so that log messages are @@ -590,7 +590,7 @@ def run( ) self.log.info("WCS fitting done") - outputWCSs = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo) + outputWCSs = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo, mapTemplate=mapTemplate) outputCatalog = wcsf.getOutputCatalog() starCatalog = wcsf.getStarCatalog() modelParams = self._compute_model_params(wcsf) if self.config.saveModelParams else None @@ -630,6 +630,7 @@ def _prep_sky(self, inputVisitSummaries, epoch, fieldName="Field"): detectorCorners = [ lsst.geom.SpherePoint(ra, dec, lsst.geom.degrees).getVector() for (ra, dec) in zip(visSum["raCorners"].ravel(), visSum["decCorners"].ravel()) + if (np.isfinite(ra) and (np.isfinite(dec))) ] allDetectorCorners.extend(detectorCorners) boundingCircle = lsst.sphgeom.ConvexPolygon.convexHull(allDetectorCorners).getBoundingCircle() @@ -738,6 +739,24 @@ def _get_exposure_info( for row in visitSummary: detector = row["id"] + + wcs = row.getWcs() + if wcs is None: + self.log.warning( + "WCS is None for visit %d, detector %d: this extension (visit/detector) will be " + "dropped.", + visit, + detector, + ) + continue + else: + wcsRA = wcs.getSkyOrigin().getRa().asRadians() + wcsDec = wcs.getSkyOrigin().getDec().asRadians() + tangentPoint = wcsfit.Gnomonic(wcsRA, wcsDec) + mapping = wcs.getFrameDict().getMapping("PIXELS", "IWC") + gbdes_wcs = wcsfit.Wcs(wcsfit.ASTMap(mapping), tangentPoint) + wcss.append(gbdes_wcs) + if detector not in detectors: detectors.append(detector) detectorBounds = wcsfit.Bounds( @@ -752,14 +771,6 @@ def _get_exposure_info( extensionDetectors.append(detector) extensionType.append("SCIENCE") - wcs = row.getWcs() - wcsRA = wcs.getSkyOrigin().getRa().asRadians() - wcsDec = wcs.getSkyOrigin().getDec().asRadians() - tangentPoint = wcsfit.Gnomonic(wcsRA, wcsDec) - mapping = wcs.getFrameDict().getMapping("PIXELS", "IWC") - gbdes_wcs = wcsfit.Wcs(wcsfit.ASTMap(mapping), tangentPoint) - wcss.append(gbdes_wcs) - fieldNumbers = list(np.ones(len(exposureNames), dtype=int) * fieldNumber) instrumentNumbers = list(np.ones(len(exposureNames), dtype=int) * instrumentNumber) @@ -926,6 +937,34 @@ def _load_refcat( return refObjects, refCovariance + @staticmethod + def _find_extension_index(extensionInfo, visit, detector): + """Find the index for a given extension from its visit and detector + number. + + If no match is found, None is returned. + + Parameters + ---------- + extensionInfo : `lsst.pipe.base.Struct` + Struct containing properties for each extension. + visit : `int` + Visit number + detector : `int` + Detector number + + Returns + ------- + extensionIndex : `int` or None + Index of this extension + """ + findExtension = np.flatnonzero((extensionInfo.visit == visit) & (extensionInfo.detector == detector)) + if len(findExtension) == 0: + extensionIndex = None + else: + extensionIndex = findExtension[0] + return extensionIndex + def _load_catalogs_and_associate( self, associations, inputCatalogRefs, extensionInfo, fieldIndex=0, instrumentIndex=0 ): @@ -999,9 +1038,12 @@ class `wcsfit.FoFClass`, associating them into matches as you go. goodInds = selected.selected & goodShapes isStar = np.ones(goodInds.sum()) - extensionIndex = np.flatnonzero( - (extensionInfo.visit == visit) & (extensionInfo.detector == detector) - )[0] + extensionIndex = self._find_extension_index(extensionInfo, visit, detector) + if extensionIndex is None: + # This extension does not have information necessary for + # fit. Skip the detections from this detector for this + # visit. + continue detectorIndex = extensionInfo.detectorIndex[extensionIndex] visitIndex = extensionInfo.visitIndex[extensionIndex] @@ -1087,6 +1129,8 @@ def make_yaml(self, inputVisitSummary, inputFile=None): ------- inputYAML : `wcsfit.YAMLCollector` YAML object containing the model description. + inputDict : `dict` [`str`, `str`] + Dictionary containing the model description. """ if inputFile is not None: inputYAML = wcsfit.YAMLCollector(inputFile, "PixelMapCollection") @@ -1137,7 +1181,7 @@ def make_yaml(self, inputVisitSummary, inputFile=None): inputYAML.addInput(yaml.dump(inputDict)) inputYAML.addInput("Identity:\n Type: Identity\n") - return inputYAML + return inputYAML, inputDict def _add_objects(self, wcsf, inputCatalogRefs, sourceIndices, extensionInfo, columns): """Add science sources to the wcsfit.WCSFit object. @@ -1164,9 +1208,13 @@ def _add_objects(self, wcsf, inputCatalogRefs, sourceIndices, extensionInfo, col for detector in detectors: detectorSources = inputCatalog[inputCatalog["detector"] == detector] - extensionIndex = np.flatnonzero( - (extensionInfo.visit == visit) & (extensionInfo.detector == detector) - )[0] + extensionIndex = self._find_extension_index(extensionInfo, visit, detector) + if extensionIndex is None: + # This extension does not have information necessary for + # fit. Skip the detections from this detector for this + # visit. + continue + sourceCat = detectorSources[sourceIndices[extensionIndex]] xCov = sourceCat["xErr"] ** 2 @@ -1288,7 +1336,7 @@ def _make_afw_wcs(self, mapDict, centerRA, centerDec, doNormalizePixels=False, x outWCS = afwgeom.SkyWcs(frameDict) return outWCS - def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo): + def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate=None): """Make a WCS object out of the WCS models. Parameters @@ -1335,8 +1383,29 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo): for d, detector in enumerate(visitSummary["id"]): mapName = f"{visit}/{detector}" - - mapElements = wcsf.mapCollection.orderAtoms(f"{mapName}/base") + if mapName in wcsf.mapCollection.allMapNames(): + mapElements = wcsf.mapCollection.orderAtoms(f"{mapName}/base") + else: + # This extension was not fit, but the WCS can be recovered + # using the maps fit from sources on other visits but the + # same detector and from sources on other detectors from + # this visit. + genericElements = mapTemplate["EXPOSURE/DEVICE/base"]["Elements"] + mapElements = [] + instrument = visitSummary[0].getVisitInfo().instrumentLabel + # Go through the generic map components to build the names + # of the specific maps for this extension. + for component in genericElements: + elements = mapTemplate[component]["Elements"] + for element in elements: + # TODO: DM-42519, gbdes sets the "BAND" to the + # instrument name currently. This will need to be + # disambiguated if we run on multiple bands at + # once. + element = element.replace("BAND", str(instrument)) + element = element.replace("EXPOSURE", str(visit)) + element = element.replace("DEVICE", str(detector)) + mapElements.append(element) mapDict = {} for m, mapElement in enumerate(mapElements): mapType = wcsf.mapCollection.getMapType(mapElement) diff --git a/tests/test_gbdesAstrometricFit.py b/tests/test_gbdesAstrometricFit.py index 2caf0f67..6420adc3 100644 --- a/tests/test_gbdesAstrometricFit.py +++ b/tests/test_gbdesAstrometricFit.py @@ -539,6 +539,31 @@ def test_run(self): np.testing.assert_array_less(absDiffX, 1e-7) np.testing.assert_array_less(absDiffY, 1e-7) + def test_missingWcs(self): + """Test that task does not fail when the input WCS is None for one + extension and that the fit WCS for that extension returns a finite + result. + """ + inputVisitSummary = self.inputVisitSummary.copy() + # Set one WCS to be None + testVisit = 0 + testDetector = 20 + inputVisitSummary[testVisit][testDetector].setWcs(None) + + outputs = self.task.run( + self.inputCatalogRefs, + inputVisitSummary, + instrumentName=self.instrumentName, + refEpoch=self.refEpoch, + refObjectLoader=self.refObjectLoader, + ) + + # Check that the fit WCS for the extension with input WCS=None returns + # finite sky values. + testWcs = outputs.outputWCSs[self.testVisits[testVisit]][testDetector].getWcs() + testSky = testWcs.pixelToSky(0, 0) + self.assertTrue(testSky.isFinite()) + def setup_module(module): lsst.utils.tests.init()