From 69415583512ef217ba706b9ebc1ba6796d09ef07 Mon Sep 17 00:00:00 2001 From: Clare Saunders Date: Fri, 15 Dec 2023 11:26:45 -0800 Subject: [PATCH] Recover WCS for input with astrometry failures --- python/lsst/drp/tasks/gbdesAstrometricFit.py | 79 +++++++++++++++----- tests/test_gbdesAstrometricFit.py | 25 +++++++ 2 files changed, 86 insertions(+), 18 deletions(-) diff --git a/python/lsst/drp/tasks/gbdesAstrometricFit.py b/python/lsst/drp/tasks/gbdesAstrometricFit.py index d7966c02..44914b5d 100644 --- a/python/lsst/drp/tasks/gbdesAstrometricFit.py +++ b/python/lsst/drp/tasks/gbdesAstrometricFit.py @@ -542,7 +542,7 @@ def run( self.log.info("Fit the WCSs") # Set up a YAML-type string using the config variables and a sample # visit - inputYAML = self.make_yaml(inputVisitSummaries[0]) + inputYAML, mapTemplate = self.make_yaml(inputVisitSummaries[0]) # Set the verbosity level for WCSFit from the task log level. # TODO: DM-36850, Add lsst.log to gbdes so that log messages are @@ -590,7 +590,7 @@ def run( ) self.log.info("WCS fitting done") - outputWCSs = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo) + outputWCSs = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo, mapTemplate=mapTemplate) outputCatalog = wcsf.getOutputCatalog() starCatalog = wcsf.getStarCatalog() modelParams = self._compute_model_params(wcsf) if self.config.saveModelParams else None @@ -630,6 +630,7 @@ def _prep_sky(self, inputVisitSummaries, epoch, fieldName="Field"): detectorCorners = [ lsst.geom.SpherePoint(ra, dec, lsst.geom.degrees).getVector() for (ra, dec) in zip(visSum["raCorners"].ravel(), visSum["decCorners"].ravel()) + if (np.isfinite(ra) and (np.isfinite(dec))) ] allDetectorCorners.extend(detectorCorners) boundingCircle = lsst.sphgeom.ConvexPolygon.convexHull(allDetectorCorners).getBoundingCircle() @@ -738,6 +739,24 @@ def _get_exposure_info( for row in visitSummary: detector = row["id"] + + wcs = row.getWcs() + if wcs is None: + self.log.warning( + "WCS is None for visit %d, detector %d: this extension (visit/detector) will be " + "dropped.", + visit, + detector, + ) + continue + else: + wcsRA = wcs.getSkyOrigin().getRa().asRadians() + wcsDec = wcs.getSkyOrigin().getDec().asRadians() + tangentPoint = wcsfit.Gnomonic(wcsRA, wcsDec) + mapping = wcs.getFrameDict().getMapping("PIXELS", "IWC") + gbdes_wcs = wcsfit.Wcs(wcsfit.ASTMap(mapping), tangentPoint) + wcss.append(gbdes_wcs) + if detector not in detectors: detectors.append(detector) detectorBounds = wcsfit.Bounds( @@ -752,14 +771,6 @@ def _get_exposure_info( extensionDetectors.append(detector) extensionType.append("SCIENCE") - wcs = row.getWcs() - wcsRA = wcs.getSkyOrigin().getRa().asRadians() - wcsDec = wcs.getSkyOrigin().getDec().asRadians() - tangentPoint = wcsfit.Gnomonic(wcsRA, wcsDec) - mapping = wcs.getFrameDict().getMapping("PIXELS", "IWC") - gbdes_wcs = wcsfit.Wcs(wcsfit.ASTMap(mapping), tangentPoint) - wcss.append(gbdes_wcs) - fieldNumbers = list(np.ones(len(exposureNames), dtype=int) * fieldNumber) instrumentNumbers = list(np.ones(len(exposureNames), dtype=int) * instrumentNumber) @@ -999,9 +1010,15 @@ class `wcsfit.FoFClass`, associating them into matches as you go. goodInds = selected.selected & goodShapes isStar = np.ones(goodInds.sum()) - extensionIndex = np.flatnonzero( + findExtension = np.flatnonzero( (extensionInfo.visit == visit) & (extensionInfo.detector == detector) - )[0] + ) + if len(findExtension) == 0: + # This extension does not have information necessary for + # fit. Skip the detections from this detector for this + # visit. + continue + extensionIndex = findExtension[0] detectorIndex = extensionInfo.detectorIndex[extensionIndex] visitIndex = extensionInfo.visitIndex[extensionIndex] @@ -1087,6 +1104,8 @@ def make_yaml(self, inputVisitSummary, inputFile=None): ------- inputYAML : `wcsfit.YAMLCollector` YAML object containing the model description. + inputDict : `dict` [`str`, `str`] + Dictionary containing the model description. """ if inputFile is not None: inputYAML = wcsfit.YAMLCollector(inputFile, "PixelMapCollection") @@ -1137,7 +1156,7 @@ def make_yaml(self, inputVisitSummary, inputFile=None): inputYAML.addInput(yaml.dump(inputDict)) inputYAML.addInput("Identity:\n Type: Identity\n") - return inputYAML + return inputYAML, inputDict def _add_objects(self, wcsf, inputCatalogRefs, sourceIndices, extensionInfo, columns): """Add science sources to the wcsfit.WCSFit object. @@ -1164,9 +1183,16 @@ def _add_objects(self, wcsf, inputCatalogRefs, sourceIndices, extensionInfo, col for detector in detectors: detectorSources = inputCatalog[inputCatalog["detector"] == detector] - extensionIndex = np.flatnonzero( + findExtension = np.flatnonzero( (extensionInfo.visit == visit) & (extensionInfo.detector == detector) - )[0] + ) + if len(findExtension) == 0: + # This extension does not have information necessary for + # fit. Skip the detections from this detector for this + # visit. + continue + extensionIndex = findExtension[0] + sourceCat = detectorSources[sourceIndices[extensionIndex]] xCov = sourceCat["xErr"] ** 2 @@ -1288,7 +1314,7 @@ def _make_afw_wcs(self, mapDict, centerRA, centerDec, doNormalizePixels=False, x outWCS = afwgeom.SkyWcs(frameDict) return outWCS - def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo): + def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate=None): """Make a WCS object out of the WCS models. Parameters @@ -1335,8 +1361,25 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo): for d, detector in enumerate(visitSummary["id"]): mapName = f"{visit}/{detector}" - - mapElements = wcsf.mapCollection.orderAtoms(f"{mapName}/base") + if mapName in wcsf.mapCollection.allMapNames(): + mapElements = wcsf.mapCollection.orderAtoms(f"{mapName}/base") + else: + # This extension was not fit, but the WCS can be recovered + # using the maps fit from sources on other visits but the + # same detector and from sources on other detectors from + # this visit. + genericElements = mapTemplate["EXPOSURE/DEVICE/base"]["Elements"] + mapElements = [] + instrument = visitSummary[0].getVisitInfo().instrumentLabel + # Go through the generic map components to build the names + # of the specific maps for this extension. + for component in genericElements: + elements = mapTemplate[component]["Elements"] + for element in elements: + element = element.replace("BAND", str(instrument)) + element = element.replace("EXPOSURE", str(visit)) + element = element.replace("DEVICE", str(detector)) + mapElements.append(element) mapDict = {} for m, mapElement in enumerate(mapElements): mapType = wcsf.mapCollection.getMapType(mapElement) diff --git a/tests/test_gbdesAstrometricFit.py b/tests/test_gbdesAstrometricFit.py index 2caf0f67..6420adc3 100644 --- a/tests/test_gbdesAstrometricFit.py +++ b/tests/test_gbdesAstrometricFit.py @@ -539,6 +539,31 @@ def test_run(self): np.testing.assert_array_less(absDiffX, 1e-7) np.testing.assert_array_less(absDiffY, 1e-7) + def test_missingWcs(self): + """Test that task does not fail when the input WCS is None for one + extension and that the fit WCS for that extension returns a finite + result. + """ + inputVisitSummary = self.inputVisitSummary.copy() + # Set one WCS to be None + testVisit = 0 + testDetector = 20 + inputVisitSummary[testVisit][testDetector].setWcs(None) + + outputs = self.task.run( + self.inputCatalogRefs, + inputVisitSummary, + instrumentName=self.instrumentName, + refEpoch=self.refEpoch, + refObjectLoader=self.refObjectLoader, + ) + + # Check that the fit WCS for the extension with input WCS=None returns + # finite sky values. + testWcs = outputs.outputWCSs[self.testVisits[testVisit]][testDetector].getWcs() + testSky = testWcs.pixelToSky(0, 0) + self.assertTrue(testSky.isFinite()) + def setup_module(module): lsst.utils.tests.init()