Skip to content

Commit

Permalink
Merge pull request #37 from lsst/tickets/DM-41635
Browse files Browse the repository at this point in the history
DM-41635: Add option to output model parameters and cov
  • Loading branch information
cmsaunders authored Dec 13, 2023
2 parents cc1f95e + d09b7f4 commit b73e2dd
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 1 deletion.
73 changes: 72 additions & 1 deletion python/lsst/drp/tasks/gbdesAstrometricFit.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,22 @@ class GbdesAstrometricFitConnections(
storageClass="ArrowNumpyDict",
dimensions=("instrument", "skymap", "tract", "physical_filter"),
)
modelParams = pipeBase.connectionTypes.Output(
doc="WCS parameter covariance.",
name="gbdesAstrometricFit_modelParams",
storageClass="ArrowNumpyDict",
dimensions=("instrument", "skymap", "tract", "physical_filter"),
)

def getSpatialBoundsConnections(self):
return ("inputVisitSummaries",)

def __init__(self, *, config=None):
super().__init__(config=config)

if not self.config.saveModelParams:
self.outputs.remove("modelParams")


class GbdesAstrometricFitConfig(
pipeBase.PipelineTaskConfig, pipelineConnections=GbdesAstrometricFitConnections
Expand Down Expand Up @@ -343,6 +355,14 @@ class GbdesAstrometricFitConfig(
doc="Set the random seed for selecting data points to reserve from the fit for validation.",
default=1234,
)
saveModelParams = pexConfig.Field(
dtype=bool,
doc=(
"Save the parameters and covariance of the WCS model. Default to "
"false because this can be very large."
),
default=False,
)

def setDefaults(self):
# Use only stars because aperture fluxes of galaxies are biased and
Expand Down Expand Up @@ -445,6 +465,8 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
butlerQC.put(outputWcs, wcsOutputRefDict[visit])
butlerQC.put(output.outputCatalog, outputRefs.outputCatalog)
butlerQC.put(output.starCatalog, outputRefs.starCatalog)
if self.config.saveModelParams:
butlerQC.put(output.modelParams, outputRefs.modelParams)

def run(
self, inputCatalogRefs, inputVisitSummaries, instrumentName="", refEpoch=None, refObjectLoader=None
Expand Down Expand Up @@ -571,9 +593,14 @@ def run(
outputWCSs = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo)
outputCatalog = wcsf.getOutputCatalog()
starCatalog = wcsf.getStarCatalog()
modelParams = self._compute_model_params(wcsf) if self.config.saveModelParams else None

return pipeBase.Struct(
outputWCSs=outputWCSs, fitModel=wcsf, outputCatalog=outputCatalog, starCatalog=starCatalog
outputWCSs=outputWCSs,
fitModel=wcsf,
outputCatalog=outputCatalog,
starCatalog=starCatalog,
modelParams=modelParams,
)

def _prep_sky(self, inputVisitSummaries, epoch, fieldName="Field"):
Expand Down Expand Up @@ -1336,3 +1363,47 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo):
catalogs[visit] = catalog

return catalogs

def _compute_model_params(self, wcsf):
"""Get the WCS model parameters and covariance and convert to a
dictionary that will be readable as a pandas dataframe or other table.
Parameters
----------
wcsf : `wcsfit.WCSFit`
WCSFit object, assumed to have fit model.
Returns
-------
modelParams : `dict`
Parameters and covariance of the best-fit WCS model.
"""
modelParamDict = wcsf.mapCollection.getParamDict()
modelCovariance = wcsf.getModelCovariance()

modelParams = {k: [] for k in ["mapName", "coordinate", "parameter", "coefficientNumber"]}
i = 0
for mapName, params in modelParamDict.items():
nCoeffs = len(params)
# There are an equal number of x and y coordinate parameters
nCoordCoeffs = nCoeffs // 2
modelParams["mapName"].extend([mapName] * nCoeffs)
modelParams["coordinate"].extend(["x"] * nCoordCoeffs)
modelParams["coordinate"].extend(["y"] * nCoordCoeffs)
modelParams["parameter"].extend(params)
modelParams["coefficientNumber"].extend(np.arange(nCoordCoeffs))
modelParams["coefficientNumber"].extend(np.arange(nCoordCoeffs))

for p in range(nCoeffs):
if p < nCoordCoeffs:
coord = "x"
else:
coord = "y"
modelParams[f"{mapName}_{coord}_{p}_cov"] = modelCovariance[i]
i += 1

# Convert the dictionary values from lists to numpy arrays.
for key, value in modelParams.items():
modelParams[key] = np.array(value)

return modelParams
11 changes: 11 additions & 0 deletions tests/test_gbdesAstrometricFit.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def setUpClass(cls):
cls.config.exposurePolyOrder = 6
cls.config.fitReserveFraction = 0
cls.config.fitReserveRandomSeed = 1234
cls.config.saveModelParams = True
cls.task = GbdesAstrometricFitTask(config=cls.config)

cls.exposureInfo, cls.exposuresHelper, cls.extensionInfo = cls.task._get_exposure_info(
Expand Down Expand Up @@ -491,6 +492,16 @@ def test_make_outputs(self):
self.assertAlmostEqual(np.mean(dDec), 0)
self.assertAlmostEqual(np.std(dDec), 0)

def test_compute_model_params(self):
"""Test the optional model parameters and covariance output."""
modelParams = pd.DataFrame(self.outputs.modelParams)
# Check that DataFrame is the expected size.
shape = modelParams.shape
self.assertEqual(shape[0] + 4, shape[1])
# Check that covariance matrix is symmetric.
covariance = (modelParams.iloc[:, 4:]).to_numpy()
np.testing.assert_allclose(covariance, covariance.T, atol=1e-18)

def test_run(self):
"""Test that run method recovers the input model parameters"""
outputMaps = self.outputs.fitModel.mapCollection.getParamDict()
Expand Down

0 comments on commit b73e2dd

Please sign in to comment.