From a107c603cd6d73ef69c4abeccee25b8e7d1d0c2b Mon Sep 17 00:00:00 2001 From: Clare Saunders Date: Tue, 13 Feb 2024 08:16:24 -0800 Subject: [PATCH] Recover wcs fix, plus other review changes --- python/lsst/drp/tasks/gbdesAstrometricFit.py | 72 +++++++++++--------- tests/test_gbdesAstrometricFit.py | 28 +++++++- 2 files changed, 66 insertions(+), 34 deletions(-) diff --git a/python/lsst/drp/tasks/gbdesAstrometricFit.py b/python/lsst/drp/tasks/gbdesAstrometricFit.py index 4c2800ad..d3516abc 100644 --- a/python/lsst/drp/tasks/gbdesAstrometricFit.py +++ b/python/lsst/drp/tasks/gbdesAstrometricFit.py @@ -120,18 +120,11 @@ def _make_ref_covariance_matrix( for i, pi in enumerate(positionParameters): for j, pj in enumerate(positionParameters): if i == j: - cov[:, k] = ( - ((refCat[f"{pi}Err"].value) ** 2 * inputUnit**2).to(units[j] * units[j]).value - ) + cov[:, k] = ((refCat[f"{pi}Err"].value) ** 2 * inputUnit**2).to(units[j] * units[j]).value elif i > j: - cov[:, k] = (refCat[f"{pj}_{pi}_Cov"].value * inputUnit**2).to_value( - units[i] * units[j] - ) + cov[:, k] = (refCat[f"{pj}_{pi}_Cov"].value * inputUnit**2).to_value(units[i] * units[j]) else: - cov[:, k] = (refCat[f"{pi}_{pj}_Cov"].value * inputUnit**2).to_value( - units[i] * units[j] - ) - + cov[:, k] = (refCat[f"{pi}_{pj}_Cov"].value * inputUnit**2).to_value(units[i] * units[j]) k += 1 return cov @@ -306,6 +299,17 @@ class GbdesAstrometricFitConfig( target=ReferenceSourceSelectorTask, doc="How to down-select the loaded astrometry reference catalog.", ) + referenceFilter = pexConfig.Field( + dtype=str, + doc="Name of filter to load from reference catalog. This is a required argument, although the values" + "returned are not used.", + default="phot_g_mean", + ) + applyRefCatProperMotion = pexConfig.Field( + dtype=bool, + doc="Apply proper motion to shift reference catalog to epoch of observations.", + default=True, + ) matchRadius = pexConfig.Field( doc="Matching tolerance between associated objects (arcseconds).", dtype=float, default=1.0 ) @@ -461,12 +465,9 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): inputRefCats = np.array([inputRefCat.dataId["htm7"] for inputRefCat in inputs["referenceCatalog"]]) inputs["referenceCatalog"] = [inputs["referenceCatalog"][v] for v in inputRefCats.argsort()] - sampleRefCat = inputs["referenceCatalog"][0].get() - refEpoch = sampleRefCat[0]["epoch"] - refConfig = LoadReferenceObjectsConfig() - refConfig.anyFilterMapsToThis = "phot_g_mean" - refConfig.requireProperMotion = True + if self.config.applyRefCatProperMotion: + refConfig.requireProperMotion = True refObjectLoader = ReferenceObjectLoader( dataIds=[ref.datasetRef.dataId for ref in inputRefCatRefs], refCats=inputs.pop("referenceCatalog"), @@ -474,9 +475,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): log=self.log, ) - output = self.run( - **inputs, instrumentName=instrumentName, refEpoch=refEpoch, refObjectLoader=refObjectLoader - ) + output = self.run(**inputs, instrumentName=instrumentName, refObjectLoader=refObjectLoader) wcsOutputRefDict = {outWcsRef.dataId["visit"]: outWcsRef for outWcsRef in outputRefs.outputWcs} for visit, outputWcs in output.outputWcss.items(): @@ -940,13 +939,17 @@ def _load_refcat( refCovariance : `list` [`float`] Flattened output covariance matrix. """ - formattedEpoch = astropy.time.Time(epoch, format="mjd") + if self.config.applyRefCatProperMotion: + formattedEpoch = astropy.time.Time(epoch, format="mjd") + else: + formattedEpoch = None - refFilter = refObjectLoader.config.anyFilterMapsToThis if region is not None: - skyRegion = refObjectLoader.loadRegion(region, refFilter, epoch=formattedEpoch) + skyRegion = refObjectLoader.loadRegion(region, self.config.referenceFilter, epoch=formattedEpoch) elif (center is not None) and (radius is not None): - skyRegion = refObjectLoader.loadSkyCircle(center, radius, refFilter, epoch=formattedEpoch) + skyRegion = refObjectLoader.loadSkyCircle( + center, radius, self.config.referenceFilter, epoch=formattedEpoch + ) else: raise RuntimeError("Either `region` or `center` and `radius` must be set.") @@ -1503,6 +1506,11 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate=None # Set up the schema for the output catalogs schema = lsst.afw.table.ExposureTable.makeMinimalSchema() schema.addField("visit", type="L", doc="Visit number") + schema.addField( + "recoveredWcs", + type="Flag", + doc="Input WCS missing, output recovered from other input visit/detectors.", + ) # Pixels will need to be rescaled before going into the mappings sampleDetector = visitSummaryTables[0][0] @@ -1523,10 +1531,11 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate=None catalog.resize(len(exposureInfo.detectors)) catalog["visit"] = visit - for d, detector in enumerate(visitSummary["id"]): + for d, detector in enumerate(exposureInfo.detectors): mapName = f"{visit}/{detector}" if mapName in wcsf.mapCollection.allMapNames(): mapElements = wcsf.mapCollection.orderAtoms(f"{mapName}/base") + catalog[d]["recoveredWcs"] = False else: # This extension was not fit, but the WCS can be recovered # using the maps fit from sources on other visits but the @@ -1548,6 +1557,7 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate=None element = element.replace("EXPOSURE", str(visit)) element = element.replace("DEVICE", str(detector)) mapElements.append(element) + catalog[d]["recoveredWcs"] = True mapDict = {} for m, mapElement in enumerate(mapElements): mapType = wcsf.mapCollection.getMapType(mapElement) @@ -1769,12 +1779,9 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): [inputs["isolatedStarCatalogs"][t] for t in inputIsolatedStarSourceTracts.argsort()] ) - sampleRefCat = inputs["referenceCatalog"][0].get() - refEpoch = sampleRefCat[0]["epoch"] - refConfig = LoadReferenceObjectsConfig() - refConfig.anyFilterMapsToThis = "phot_g_mean" - refConfig.requireProperMotion = True + if self.config.applyRefCatProperMotion: + refConfig.requireProperMotion = True refObjectLoader = ReferenceObjectLoader( dataIds=[ref.datasetRef.dataId for ref in inputRefCatRefs], refCats=inputs.pop("referenceCatalog"), @@ -1782,9 +1789,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): log=self.log, ) - output = self.run( - **inputs, instrumentName=instrumentName, refEpoch=refEpoch, refObjectLoader=refObjectLoader - ) + output = self.run(**inputs, instrumentName=instrumentName, refObjectLoader=refObjectLoader) for outputRef in outputRefs.outputWcs: visit = outputRef.dataId["visit"] @@ -1866,7 +1871,7 @@ def run( self.log.info("Fit the WCSs") # Set up a YAML-type string using the config variables and a sample # visit - inputYaml = self.make_yaml(inputVisitSummaries[0]) + inputYaml, mapTemplate = self.make_yaml(inputVisitSummaries[0]) # Set the verbosity level for WCSFit from the task log level. # TODO: DM-36850, Add lsst.log to gbdes so that log messages are @@ -1911,7 +1916,7 @@ def run( ) self.log.info("WCS fitting done") - outputWcss = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo) + outputWcss = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo, mapTemplate=mapTemplate) outputCatalog = wcsf.getOutputCatalog() starCatalog = wcsf.getStarCatalog() modelParams = self._compute_model_params(wcsf) if self.config.saveModelParams else None @@ -1955,6 +1960,7 @@ def _prep_sky(self, inputVisitSummaries): detectorCorners = [ lsst.geom.SpherePoint(ra, dec, lsst.geom.degrees).getVector() for (ra, dec) in zip(visSum["raCorners"].ravel(), visSum["decCorners"].ravel()) + if (np.isfinite(ra) and (np.isfinite(dec))) ] allDetectorCorners.append(detectorCorners) diff --git a/tests/test_gbdesAstrometricFit.py b/tests/test_gbdesAstrometricFit.py index 931d1708..f0dcb1a9 100644 --- a/tests/test_gbdesAstrometricFit.py +++ b/tests/test_gbdesAstrometricFit.py @@ -638,7 +638,7 @@ def test_missingWcs(self): # Check that the fit WCS for the extension with input WCS=None returns # finite sky values. - testWcs = outputs.outputWCSs[self.testVisits[testVisit]][testDetector].getWcs() + testWcs = outputs.outputWcss[self.testVisits[testVisit]][testDetector].getWcs() testSky = testWcs.pixelToSky(0, 0) self.assertTrue(testSky.isFinite()) @@ -941,6 +941,32 @@ def test_make_outputs(self): self.assertAlmostEqual(np.mean(dDec), 0) self.assertAlmostEqual(np.std(dDec), 0) + def test_missingWcs(self): + """Test that task does not fail when the input WCS is None for one + extension and that the fit WCS for that extension returns a finite + result. + """ + inputVisitSummary = self.inputVisitSummary.copy() + # Set one WCS to be None + testVisit = 0 + testDetector = 20 + inputVisitSummary[testVisit][testDetector].setWcs(None) + + outputs = self.task.run( + inputVisitSummary, + self.isolatedStarSources, + self.isolatedStarCatalogs, + instrumentName=self.instrumentName, + refEpoch=self.refEpoch, + refObjectLoader=self.refObjectLoader, + ) + + # Check that the fit WCS for the extension with input WCS=None returns + # finite sky values. + testWcs = outputs.outputWcss[self.testVisits[testVisit]][testDetector].getWcs() + testSky = testWcs.pixelToSky(0, 0) + self.assertTrue(testSky.isFinite()) + def setup_module(module): lsst.utils.tests.init()