From 9e9e448e35741beb15c8b23fd9758a9e7abafe57 Mon Sep 17 00:00:00 2001 From: Clare Saunders Date: Thu, 18 Jul 2024 12:54:52 -0700 Subject: [PATCH] Add option to correct for DCR --- python/lsst/drp/tasks/gbdesAstrometricFit.py | 210 +++++++++++++++++-- tests/test_gbdesAstrometricFit.py | 78 +++++++ 2 files changed, 265 insertions(+), 23 deletions(-) diff --git a/python/lsst/drp/tasks/gbdesAstrometricFit.py b/python/lsst/drp/tasks/gbdesAstrometricFit.py index bb4ddaf1..d13fb3a3 100644 --- a/python/lsst/drp/tasks/gbdesAstrometricFit.py +++ b/python/lsst/drp/tasks/gbdesAstrometricFit.py @@ -279,6 +279,12 @@ class GbdesAstrometricFitConnections( deferLoad=True, multiple=True, ) + colorCatalog = pipeBase.connectionTypes.Input( + doc="The catalog of magnitudes to match to input sources.", + name="fgcm_Cycle4_StandardStars", + storageClass="SimpleCatalog", + dimensions=("instrument",), + ) inputCameraModel = pipeBase.connectionTypes.PrerequisiteInput( doc="Camera parameters to use for 'device' part of model", name="gbdesAstrometricFit_cameraModel", @@ -326,6 +332,12 @@ class GbdesAstrometricFitConnections( storageClass="ArrowNumpyDict", dimensions=("instrument", "physical_filter"), ) + dcrCoefficients = pipeBase.connectionTypes.Output( + doc="Per-visit coefficients for DCR correction.", + name="gbdesAstrometricFit_dcrCoefficients", + storageClass="ArrowNumpyDict", + dimensions=("instrument", "skymap", "tract", "physical_filter"), + ) def getSpatialBoundsConnections(self): return ("inputVisitSummaries",) @@ -333,6 +345,9 @@ def getSpatialBoundsConnections(self): def __init__(self, *, config=None): super().__init__(config=config) + if not self.config.useColor: + self.inputs.remove("colorCatalog") + self.outputs.remove("dcrCoefficients") if not self.config.saveModelParams: self.outputs.remove("modelParams") if not self.config.useInputCameraModel: @@ -393,6 +408,22 @@ class GbdesAstrometricFitConfig( doc="Systematic error padding added in quadrature for the reference catalog (marcsec).", default=0.0, ) + useColor = pexConfig.Field( + dtype=bool, + doc="Use color information to correct for differential chromatic refraction.", + default=False, + ) + color = pexConfig.ListField( + dtype=str, + doc="The bands to use for calculating color.", + default=["g", "i"], + listCheck=(lambda x: (len(x) == 2) and (len(set(x)) == len(x))), + ) + referenceColor = pexConfig.Field( + dtype=float, + doc="The color for which DCR is defined as zero.", + default=0.61, + ) modelComponents = pexConfig.ListField( dtype=str, doc=( @@ -487,7 +518,8 @@ def validate(self): # Check if all components of the device and exposure models are # supported. for component in self.deviceModel: - if not (("poly" in component.lower()) or ("identity" in component.lower())): + mapping = component.split("/")[-1] + if mapping not in ["poly", "identity"]: raise pexConfig.FieldValidationError( GbdesAstrometricFitConfig.deviceModel, self, @@ -495,7 +527,8 @@ def validate(self): ) for component in self.exposureModel: - if not (("poly" in component.lower()) or ("identity" in component.lower())): + mapping = component.split("/")[-1] + if mapping not in ["poly", "identity", "dcr"]: raise pexConfig.FieldValidationError( GbdesAstrometricFitConfig.exposureModel, self, @@ -549,7 +582,17 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): log=self.log, ) - output = self.run(**inputs, instrumentName=instrumentName, refObjectLoader=refObjectLoader) + if self.config.useColor: + colorCatalog = inputs.pop("colorCatalog") + else: + colorCatalog = None + + output = self.run( + **inputs, + instrumentName=instrumentName, + refObjectLoader=refObjectLoader, + colorCatalog=colorCatalog, + ) wcsOutputRefDict = {outWcsRef.dataId["visit"]: outWcsRef for outWcsRef in outputRefs.outputWcs} for visit, outputWcs in output.outputWcss.items(): @@ -560,6 +603,8 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): butlerQC.put(output.modelParams, outputRefs.modelParams) if self.config.saveCameraModel: butlerQC.put(output.cameraModelParams, outputRefs.outputCameraModel) + if self.config.useColor: + butlerQC.put(output.colorParams, outputRefs.dcrCoefficients) def run( self, @@ -569,6 +614,7 @@ def run( refEpoch=None, refObjectLoader=None, inputCameraModel=None, + colorCatalog=None, ): """Run the WCS fit for a given set of visits @@ -585,9 +631,11 @@ def run( Epoch of the reference objects in MJD. refObjectLoader : instance of `lsst.meas.algorithms.loadReferenceObjects.ReferenceObjectLoader` - Referencef object loader instance. + Reference object loader instance. inputCameraModel : `dict` [`str`, `np.ndarray`], optional Parameters to use for the device part of the model. + colorCatalog : `lsst.afw.table.SimpleCatalog` + Catalog containing object coordinates and magnitudes. Returns ------- @@ -606,6 +654,8 @@ def run( ``cameraModelParams`` : `dict` [`str`, `np.ndarray`] Parameters of the device part of the model, in the format needed as input for future runs. + ``colorParams`` : `dict` [`int`, `np.ndarray`] + DCR parameters fit in RA and Dec directions for each visit. """ self.log.info("Gather instrument, exposure, and field info") @@ -687,6 +737,8 @@ def run( # Add the science and reference sources self._add_objects(wcsf, inputCatalogRefs, sourceIndices, extensionInfo, usedColumns) self._add_ref_objects(wcsf, refObjects, refCovariance, extensionInfo) + if self.config.useColor: + self._add_color_objects(wcsf, colorCatalog) # There must be at least as many sources per visit as the number of # free parameters in the per-visit mapping. Set minFitExposures to be @@ -701,7 +753,7 @@ def run( ) self.log.info("WCS fitting done") - outputWcss, cameraParams = self._make_outputs( + outputWcss, cameraParams, colorParams = self._make_outputs( wcsf, inputVisitSummaries, exposureInfo, @@ -719,6 +771,7 @@ def run( starCatalog=starCatalog, modelParams=modelParams, cameraModelParams=cameraParams, + colorParams=colorParams, ) def _prep_sky(self, inputVisitSummaries, epoch, fieldName="Field"): @@ -1346,7 +1399,8 @@ def make_yaml(self, inputVisitSummary, inputFile=None, inputCameraModel=None): deviceModel = {"Type": "Composite", "Elements": self.config.deviceModel.list()} inputDict["BAND/DEVICE"] = deviceModel for component in self.config.deviceModel: - if "poly" in component.lower(): + mapping = component.split("/")[-1] + if mapping == "poly": componentDict = { "Type": "Poly", "XPoly": {"OrderX": self.config.devicePolyOrder, "SumOrder": True}, @@ -1356,7 +1410,7 @@ def make_yaml(self, inputVisitSummary, inputFile=None, inputCameraModel=None): "YMin": yMin, "YMax": yMax, } - elif "identity" in component.lower(): + elif mapping == "identity": componentDict = {"Type": "Identity"} inputDict[component] = componentDict @@ -1388,17 +1442,27 @@ def make_yaml(self, inputVisitSummary, inputFile=None, inputCameraModel=None): } inputDict[key] = mapDict - exposureModel = {"Type": "Composite", "Elements": self.config.exposureModel.list()} + exposureModelComponents = self.config.exposureModel.list() + if self.config.useColor: + exposureModelComponents.append("EXPOSURE/dcr") + exposureModel = {"Type": "Composite", "Elements": exposureModelComponents} inputDict["EXPOSURE"] = exposureModel - for component in self.config.exposureModel: - if "poly" in component.lower(): + for component in exposureModelComponents: + mapping = component.split("/")[-1] + if mapping == "poly": componentDict = { "Type": "Poly", "XPoly": {"OrderX": self.config.exposurePolyOrder, "SumOrder": "true"}, "YPoly": {"OrderX": self.config.exposurePolyOrder, "SumOrder": "true"}, } - elif "identity" in component.lower(): + elif mapping == "identity": componentDict = {"Type": "Identity"} + elif mapping == "dcr": + componentDict = { + "Type": "Color", + "Reference": self.config.referenceColor, + "Function": {"Type": "Constant"}, + } inputDict[component] = componentDict @@ -1466,7 +1530,14 @@ def _add_objects(self, wcsf, inputCatalogRefs, sourceIndices, extensionInfo, col "xyCov": xyCov.to_numpy(), } - wcsf.setObjects(extensionIndex, d, "x", "y", ["xCov", "yCov", "xyCov"]) + wcsf.setObjects( + extensionIndex, + d, + "x", + "y", + ["xCov", "yCov", "xyCov"], + defaultColor=self.config.referenceColor, + ) def _add_ref_objects(self, wcsf, refObjects, refCovariance, extensionInfo, fieldIndex=0): """Add reference sources to the wcsfit.WCSFit object. @@ -1515,6 +1586,43 @@ def _add_ref_objects(self, wcsf, refObjects, refCovariance, extensionInfo, field else: wcsf.setObjects(extensionIndex, refObjects, "ra", "dec", ["raCov", "decCov", "raDecCov"]) + def _add_color_objects(self, wcsf, colorCatalog): + """Associate input matches with objects in color catalog and set their + color value. + + Parameters + ---------- + wcsf : `wcsfit.WCSFit` + WCSFit object, assumed to have fit model. + colorCatalog : `lsst.afw.table.SimpleCatalog` + Catalog containing object coordinates and magnitudes. + """ + + # Get current best position for matches + starCat = wcsf.getStarCatalog() + + # TODO: DM-45650, update how the colors are read in here. + catalogBands = colorCatalog.metadata.getArray("BANDS") + colorInd1 = catalogBands.index(self.config.color[0]) + colorInd2 = catalogBands.index(self.config.color[1]) + colors = colorCatalog["mag_std_noabs"][:, colorInd1] - colorCatalog["mag_std_noabs"][:, colorInd2] + goodInd = (colorCatalog["mag_std_noabs"][:, colorInd1] != 99.0) & ( + colorCatalog["mag_std_noabs"][:, colorInd2] != 99.0 + ) + + with Matcher(np.array(starCat["starX"]), np.array(starCat["starY"])) as matcher: + idx, idx_starCat, idx_colorCat, d = matcher.query_radius( + (colorCatalog[goodInd]["coord_ra"] * u.radian).to(u.degree).value, + (colorCatalog[goodInd]["coord_dec"] * u.radian).to(u.degree).value, + self.config.matchRadius / 3600.0, + return_indices=True, + ) + + matchesWithColor = starCat["starMatchID"][idx_starCat] + matchColors = np.ones(len(matchesWithColor)) * self.config.referenceColor + matchColors = colors[goodInd][idx_colorCat] + wcsf.setColors(matchesWithColor, matchColors) + def _make_afw_wcs(self, mapDict, centerRA, centerDec, doNormalizePixels=False, xScale=1, yScale=1): """Make an `lsst.afw.geom.SkyWcs` from a dictionary of mappings. @@ -1625,6 +1733,8 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate, inp cameraParams : `dict` [`str`, `np.ndarray`], optional Parameters for the device part of the model in the format needed when used as input for future runs. + colorFits : `dict` [`int`, `np.ndarray`], optional + DCR parameters fit in RA and Dec directions for each visit. """ # Get the parameters of the fit models mapParams = wcsf.mapCollection.getParamDict() @@ -1659,10 +1769,15 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate, inp yscale = sampleDetector["bbox_max_y"] - sampleDetector["bbox_min_y"] catalogs = {} + colorFits = {} for v, visitSummary in enumerate(visitSummaryTables): visit = visitSummary[0]["visit"] - visitMap = wcsf.mapCollection.orderAtoms(f"{visit}")[0] + visitMaps = wcsf.mapCollection.orderAtoms(f"{visit}") + if self.config.useColor: + colorMap = visitMaps.pop(visitMaps.index(f"{visit}/dcr")) + colorFits[visit] = mapParams[colorMap] + visitMap = visitMaps[0] visitMapType = wcsf.mapCollection.getMapType(visitMap) if (visitMap not in mapParams) and (visitMapType != "Identity"): self.log.warning("Visit %d was dropped because of an insufficient number of sources.", visit) @@ -1703,6 +1818,9 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate, inp mapDict = {} for m, mapElement in enumerate(mapElements): mapType = wcsf.mapCollection.getMapType(mapElement) + if mapType == "Color": + # DCR fit should not go into the generic WCS. + continue mapDict[mapElement] = {"Type": mapType} if mapType == "Poly": @@ -1724,8 +1842,13 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate, inp catalog[d].setWcs(outWCS) catalog.sort() catalogs[visit] = catalog + if self.config.useColor: + colorVisits = np.array(list(colorFits.keys())) + colorRA = np.array([colorFits[vis][0] for vis in colorVisits]) + colorDec = np.array([colorFits[vis][1] for vis in colorVisits]) + colorFits = {"visit": colorVisits, "raCoefficient": colorRA, "decCoefficient": colorDec} - return catalogs, cameraParams + return catalogs, cameraParams, colorFits def _compute_model_params(self, wcsf): """Get the WCS model parameters and covariance and convert to a @@ -1832,6 +1955,12 @@ class GbdesGlobalAstrometricFitConnections( deferLoad=True, multiple=True, ) + colorCatalog = pipeBase.connectionTypes.Input( + doc="The catalog of magnitudes to match to input sources.", + name="fgcm_Cycle4_StandardStars", + storageClass="SimpleCatalog", + dimensions=("instrument",), + ) isolatedStarSources = pipeBase.connectionTypes.Input( doc="Catalog of matched sources.", name="isolated_star_sources", @@ -1903,6 +2032,12 @@ class GbdesGlobalAstrometricFitConnections( storageClass="ArrowNumpyDict", dimensions=("instrument", "physical_filter"), ) + dcrCoefficients = pipeBase.connectionTypes.Output( + doc="Per-visit coefficients for DCR correction.", + name="gbdesGlobalAstrometricFit_dcrCoefficients", + storageClass="ArrowNumpyDict", + dimensions=("instrument", "physical_filter"), + ) def getSpatialBoundsConnections(self): return ("inputVisitSummaries",) @@ -1910,6 +2045,9 @@ def getSpatialBoundsConnections(self): def __init__(self, *, config=None): super().__init__(config=config) + if not self.config.useColor: + self.inputs.remove("colorCatalog") + self.outputs.remove("dcrCoefficients") if not self.config.saveModelParams: self.outputs.remove("modelParams") if not self.config.useInputCameraModel: @@ -1984,8 +2122,17 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): config=refConfig, log=self.log, ) + if self.config.useColor: + colorCatalog = inputs.pop("colorCatalog") + else: + colorCatalog = None - output = self.run(**inputs, instrumentName=instrumentName, refObjectLoader=refObjectLoader) + output = self.run( + **inputs, + instrumentName=instrumentName, + refObjectLoader=refObjectLoader, + colorCatalog=colorCatalog, + ) for outputRef in outputRefs.outputWcs: visit = outputRef.dataId["visit"] @@ -1996,6 +2143,8 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): butlerQC.put(output.modelParams, outputRefs.modelParams) if self.config.saveCameraModel: butlerQC.put(output.cameraModelParams, outputRefs.outputCameraModel) + if self.config.useColor: + butlerQC.put(output.colorParams, outputRefs.dcrCoefficients) def run( self, @@ -2006,6 +2155,7 @@ def run( refEpoch=None, refObjectLoader=None, inputCameraModel=None, + colorCatalog=None, ): """Run the WCS fit for a given set of visits @@ -2027,6 +2177,8 @@ def run( Reference object loader instance. inputCameraModel : `dict` [`str`, `np.ndarray`], optional Parameters to use for the device part of the model. + colorCatalog : `lsst.afw.table.SimpleCatalog` + Catalog containing object coordinates and magnitudes. Returns ------- @@ -2045,6 +2197,8 @@ def run( ``cameraModelParams`` : `dict` [`str`, `np.ndarray`] Parameters of the device part of the model, in the format needed as input for future runs. + ``colorParams`` : `dict` [`int`, `np.ndarray`] + DCR parameters fit in RA and Dec directions for each visit. """ self.log.info("Gather instrument, exposure, and field info") @@ -2114,6 +2268,8 @@ def run( self._add_ref_objects( wcsf, allRefObjects[f], allRefCovariances[f], extensionInfo, fieldIndex=-1 * f ) + if self.config.useColor: + self._add_color_objects(wcsf, colorCatalog) # Do the WCS fit wcsf.fit( @@ -2121,7 +2277,7 @@ def run( ) self.log.info("WCS fitting done") - outputWcss, cameraParams = self._make_outputs( + outputWcss, cameraParams, colorParams = self._make_outputs( wcsf, inputVisitSummaries, exposureInfo, @@ -2139,6 +2295,7 @@ def run( starCatalog=starCatalog, modelParams=modelParams, cameraModelParams=cameraParams, + colorParams=colorParams, ) def _prep_sky(self, inputVisitSummaries): @@ -2303,22 +2460,22 @@ def _associate_from_isolated_sources( with Matcher( isolatedStarCatalog["ra"].to_numpy(), isolatedStarCatalog["dec"].to_numpy() ) as matcher: - idx, i1, i2, d = matcher.query_radius( + idx, idx_isoStarCat, idx_refObjects, d = matcher.query_radius( np.array(regionRefObjects["ra"]), np.array(regionRefObjects["dec"]), self.config.matchRadius / 3600.0, return_indices=True, ) - refSort = np.searchsorted(isolatedStarSources["obj_index"], i1) - refDetector = np.ones(len(i1)) * -1 + refSort = np.searchsorted(isolatedStarSources["obj_index"], idx_isoStarCat) + refDetector = np.ones(len(idx_isoStarCat)) * -1 # The "visit" for the reference catalogs is the field times -1. - refVisit = np.ones(len(i1)) * f * -1 + refVisit = np.ones(len(idx_isoStarCat)) * f * -1 allVisits = np.insert(allVisits, refSort, refVisit) allDetectors = np.insert(allDetectors, refSort, refDetector) - allObjectIndices = np.insert(allObjectIndices, refSort, i1) - issIndices = np.insert(issIndices, refSort, i2) + allObjectIndices = np.insert(allObjectIndices, refSort, idx_isoStarCat) + issIndices = np.insert(issIndices, refSort, idx_refObjects) # Loop through the associated sources to convert them to the gbdes # format, which requires the extension index, the source's index in @@ -2407,7 +2564,14 @@ def _add_objects(self, wcsf, sourceDict, extensionInfo): "yCov": np.array(sourceCat["yCov"]), "xyCov": np.array(sourceCat["xyCov"]), } - wcsf.setObjects(extensionIndex, d, "x", "y", ["xCov", "yCov", "xyCov"]) + wcsf.setObjects( + extensionIndex, + d, + "x", + "y", + ["xCov", "yCov", "xyCov"], + defaultColor=self.config.referenceColor, + ) class GbdesGlobalAstrometricMultibandFitConnections( diff --git a/tests/test_gbdesAstrometricFit.py b/tests/test_gbdesAstrometricFit.py index 3690fee8..66a78665 100644 --- a/tests/test_gbdesAstrometricFit.py +++ b/tests/test_gbdesAstrometricFit.py @@ -149,6 +149,8 @@ def setUpClass(cls): # Make source catalogs: cls.inputCatalogRefs = cls._make_sourceCat(starIds, starRAs, starDecs, trueWCSs, inScienceFraction) + cls.colorCatalog = cls._make_colors(starRAs, starDecs) + cls.outputs = cls.task.run( cls.inputCatalogRefs, cls.inputVisitSummary, @@ -436,6 +438,48 @@ def _make_wcs(cls, model, inputVisitSummaries): return catalogs + @classmethod + def _make_colors(cls, starRas, starDecs): + """Make a catalog with the star magnitudes. + + Parameters + ---------- + starRas : `np.ndarray` [`float`] + RAs of the simulated stars + starDecs : `np.ndarray` [`float`] + Decs of the simulated stars + + Returns + ------- + colorCatalog : `lsst.afw.table.SimpleCatalog` + Catalog with star magnitudes. + """ + bands = ["g", "r", "i", "z", "y"] + nStars = len(starRas) + + # Make a catalog following what is done in `fgcmCal`. + schema = afwTable.SimpleTable.makeMinimalSchema() + schema.addField( + "mag_std_noabs", + type="ArrayF", + doc="Standard magnitude (no absolute calibration)", + size=len(bands), + ) + colorCatalog = afwTable.SimpleCatalog(schema) + colorCatalog.resize(len(starRas)) + colorCatalog["coord_ra"] = (starRas * u.degree + 10 * np.random.randn(nStars) * u.mas).to(u.radian) + colorCatalog["coord_dec"] = (starDecs * u.degree + 10 * np.random.randn(nStars) * u.mas).to(u.radian) + + magMin = 19 + magMax = 23 + for i in range(len(bands)): + colorCatalog["mag_std_noabs"][:, i] = np.random.random(nStars) * (magMax - magMin) + magMin + + md = PropertyList() + md.set("BANDS", bands) + colorCatalog.setMetadata(md) + return colorCatalog + def test_get_exposure_info(self): """Test that information for input exposures is as expected and that the WCS in the class object gives approximately the same results as the @@ -687,6 +731,21 @@ def test_inputCameraModel(self): self.assertAlmostEqual(np.mean(dDec), 0) self.assertAlmostEqual(np.std(dDec), 0) + def test_useColor(self): + """Test running task with color catalog and DCR fitting.""" + config = copy(self.config) + config.useColor = True + + task = GbdesAstrometricFitTask(config=config) + outputs = task.run( + self.inputCatalogRefs, + self.inputVisitSummary, + instrumentName=self.instrumentName, + refObjectLoader=self.refObjectLoader, + colorCatalog=self.colorCatalog, + ) + self.assertEqual(len(outputs.colorParams["visit"]), len(self.testVisits)) + class TestGbdesGlobalAstrometricFit(TestGbdesAstrometricFit): @classmethod @@ -795,6 +854,8 @@ def setUpClass(cls): allStarIds, allStarRAs, allStarDecs, cls.trueWCSs, inScienceFraction ) + cls.colorCatalog = cls._make_colors(np.concatenate(allStarRAs), np.concatenate(allStarDecs)) + cls.outputs = cls.task.run( cls.inputVisitSummary, cls.isolatedStarSources, @@ -1057,6 +1118,23 @@ def test_inputCameraModel(self): self.assertAlmostEqual(np.mean(dDec), 0, places=6) self.assertAlmostEqual(np.std(dDec), 0) + def test_useColor(self): + """Test running task with color catalog and DCR fitting.""" + config = copy(self.config) + config.useColor = True + + task = GbdesGlobalAstrometricFitTask(config=config) + + outputs = task.run( + self.inputVisitSummary, + self.isolatedStarSources, + self.isolatedStarCatalogs, + instrumentName=self.instrumentName, + refObjectLoader=self.refObjectLoader, + colorCatalog=self.colorCatalog, + ) + self.assertEqual(len(outputs.colorParams["visit"]), len(self.testVisits)) + def setup_module(module): lsst.utils.tests.init()