Skip to content

Commit

Permalink
Recover wcs fix, plus other review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
cmsaunders committed Feb 13, 2024
1 parent ef70b1d commit a107c60
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 34 deletions.
72 changes: 39 additions & 33 deletions python/lsst/drp/tasks/gbdesAstrometricFit.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,11 @@ def _make_ref_covariance_matrix(
for i, pi in enumerate(positionParameters):
for j, pj in enumerate(positionParameters):
if i == j:
cov[:, k] = (
((refCat[f"{pi}Err"].value) ** 2 * inputUnit**2).to(units[j] * units[j]).value
)
cov[:, k] = ((refCat[f"{pi}Err"].value) ** 2 * inputUnit**2).to(units[j] * units[j]).value
elif i > j:
cov[:, k] = (refCat[f"{pj}_{pi}_Cov"].value * inputUnit**2).to_value(
units[i] * units[j]
)
cov[:, k] = (refCat[f"{pj}_{pi}_Cov"].value * inputUnit**2).to_value(units[i] * units[j])
else:
cov[:, k] = (refCat[f"{pi}_{pj}_Cov"].value * inputUnit**2).to_value(
units[i] * units[j]
)

cov[:, k] = (refCat[f"{pi}_{pj}_Cov"].value * inputUnit**2).to_value(units[i] * units[j])
k += 1
return cov

Expand Down Expand Up @@ -306,6 +299,17 @@ class GbdesAstrometricFitConfig(
target=ReferenceSourceSelectorTask,
doc="How to down-select the loaded astrometry reference catalog.",
)
referenceFilter = pexConfig.Field(
dtype=str,
doc="Name of filter to load from reference catalog. This is a required argument, although the values"
"returned are not used.",
default="phot_g_mean",
)
applyRefCatProperMotion = pexConfig.Field(
dtype=bool,
doc="Apply proper motion to shift reference catalog to epoch of observations.",
default=True,
)
matchRadius = pexConfig.Field(
doc="Matching tolerance between associated objects (arcseconds).", dtype=float, default=1.0
)
Expand Down Expand Up @@ -461,22 +465,17 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputRefCats = np.array([inputRefCat.dataId["htm7"] for inputRefCat in inputs["referenceCatalog"]])
inputs["referenceCatalog"] = [inputs["referenceCatalog"][v] for v in inputRefCats.argsort()]

sampleRefCat = inputs["referenceCatalog"][0].get()
refEpoch = sampleRefCat[0]["epoch"]

refConfig = LoadReferenceObjectsConfig()
refConfig.anyFilterMapsToThis = "phot_g_mean"
refConfig.requireProperMotion = True
if self.config.applyRefCatProperMotion:
refConfig.requireProperMotion = True
refObjectLoader = ReferenceObjectLoader(
dataIds=[ref.datasetRef.dataId for ref in inputRefCatRefs],
refCats=inputs.pop("referenceCatalog"),
config=refConfig,
log=self.log,
)

output = self.run(
**inputs, instrumentName=instrumentName, refEpoch=refEpoch, refObjectLoader=refObjectLoader
)
output = self.run(**inputs, instrumentName=instrumentName, refObjectLoader=refObjectLoader)

wcsOutputRefDict = {outWcsRef.dataId["visit"]: outWcsRef for outWcsRef in outputRefs.outputWcs}
for visit, outputWcs in output.outputWcss.items():
Expand Down Expand Up @@ -940,13 +939,17 @@ def _load_refcat(
refCovariance : `list` [`float`]
Flattened output covariance matrix.
"""
formattedEpoch = astropy.time.Time(epoch, format="mjd")
if self.config.applyRefCatProperMotion:
formattedEpoch = astropy.time.Time(epoch, format="mjd")
else:
formattedEpoch = None

refFilter = refObjectLoader.config.anyFilterMapsToThis
if region is not None:
skyRegion = refObjectLoader.loadRegion(region, refFilter, epoch=formattedEpoch)
skyRegion = refObjectLoader.loadRegion(region, self.config.referenceFilter, epoch=formattedEpoch)
elif (center is not None) and (radius is not None):
skyRegion = refObjectLoader.loadSkyCircle(center, radius, refFilter, epoch=formattedEpoch)
skyRegion = refObjectLoader.loadSkyCircle(
center, radius, self.config.referenceFilter, epoch=formattedEpoch
)
else:
raise RuntimeError("Either `region` or `center` and `radius` must be set.")

Expand Down Expand Up @@ -1503,6 +1506,11 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate=None
# Set up the schema for the output catalogs
schema = lsst.afw.table.ExposureTable.makeMinimalSchema()
schema.addField("visit", type="L", doc="Visit number")
schema.addField(
"recoveredWcs",
type="Flag",
doc="Input WCS missing, output recovered from other input visit/detectors.",
)

# Pixels will need to be rescaled before going into the mappings
sampleDetector = visitSummaryTables[0][0]
Expand All @@ -1523,10 +1531,11 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate=None
catalog.resize(len(exposureInfo.detectors))
catalog["visit"] = visit

for d, detector in enumerate(visitSummary["id"]):
for d, detector in enumerate(exposureInfo.detectors):
mapName = f"{visit}/{detector}"
if mapName in wcsf.mapCollection.allMapNames():
mapElements = wcsf.mapCollection.orderAtoms(f"{mapName}/base")
catalog[d]["recoveredWcs"] = False
else:
# This extension was not fit, but the WCS can be recovered
# using the maps fit from sources on other visits but the
Expand All @@ -1548,6 +1557,7 @@ def _make_outputs(self, wcsf, visitSummaryTables, exposureInfo, mapTemplate=None
element = element.replace("EXPOSURE", str(visit))
element = element.replace("DEVICE", str(detector))
mapElements.append(element)
catalog[d]["recoveredWcs"] = True
mapDict = {}
for m, mapElement in enumerate(mapElements):
mapType = wcsf.mapCollection.getMapType(mapElement)
Expand Down Expand Up @@ -1769,22 +1779,17 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
[inputs["isolatedStarCatalogs"][t] for t in inputIsolatedStarSourceTracts.argsort()]
)

sampleRefCat = inputs["referenceCatalog"][0].get()
refEpoch = sampleRefCat[0]["epoch"]

refConfig = LoadReferenceObjectsConfig()
refConfig.anyFilterMapsToThis = "phot_g_mean"
refConfig.requireProperMotion = True
if self.config.applyRefCatProperMotion:
refConfig.requireProperMotion = True
refObjectLoader = ReferenceObjectLoader(
dataIds=[ref.datasetRef.dataId for ref in inputRefCatRefs],
refCats=inputs.pop("referenceCatalog"),
config=refConfig,
log=self.log,
)

output = self.run(
**inputs, instrumentName=instrumentName, refEpoch=refEpoch, refObjectLoader=refObjectLoader
)
output = self.run(**inputs, instrumentName=instrumentName, refObjectLoader=refObjectLoader)

for outputRef in outputRefs.outputWcs:
visit = outputRef.dataId["visit"]
Expand Down Expand Up @@ -1866,7 +1871,7 @@ def run(
self.log.info("Fit the WCSs")
# Set up a YAML-type string using the config variables and a sample
# visit
inputYaml = self.make_yaml(inputVisitSummaries[0])
inputYaml, mapTemplate = self.make_yaml(inputVisitSummaries[0])

# Set the verbosity level for WCSFit from the task log level.
# TODO: DM-36850, Add lsst.log to gbdes so that log messages are
Expand Down Expand Up @@ -1911,7 +1916,7 @@ def run(
)
self.log.info("WCS fitting done")

outputWcss = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo)
outputWcss = self._make_outputs(wcsf, inputVisitSummaries, exposureInfo, mapTemplate=mapTemplate)
outputCatalog = wcsf.getOutputCatalog()
starCatalog = wcsf.getStarCatalog()
modelParams = self._compute_model_params(wcsf) if self.config.saveModelParams else None
Expand Down Expand Up @@ -1955,6 +1960,7 @@ def _prep_sky(self, inputVisitSummaries):
detectorCorners = [
lsst.geom.SpherePoint(ra, dec, lsst.geom.degrees).getVector()
for (ra, dec) in zip(visSum["raCorners"].ravel(), visSum["decCorners"].ravel())
if (np.isfinite(ra) and (np.isfinite(dec)))
]
allDetectorCorners.append(detectorCorners)

Expand Down
28 changes: 27 additions & 1 deletion tests/test_gbdesAstrometricFit.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def test_missingWcs(self):

# Check that the fit WCS for the extension with input WCS=None returns
# finite sky values.
testWcs = outputs.outputWCSs[self.testVisits[testVisit]][testDetector].getWcs()
testWcs = outputs.outputWcss[self.testVisits[testVisit]][testDetector].getWcs()
testSky = testWcs.pixelToSky(0, 0)
self.assertTrue(testSky.isFinite())

Expand Down Expand Up @@ -941,6 +941,32 @@ def test_make_outputs(self):
self.assertAlmostEqual(np.mean(dDec), 0)
self.assertAlmostEqual(np.std(dDec), 0)

def test_missingWcs(self):
"""Test that task does not fail when the input WCS is None for one
extension and that the fit WCS for that extension returns a finite
result.
"""
inputVisitSummary = self.inputVisitSummary.copy()
# Set one WCS to be None
testVisit = 0
testDetector = 20
inputVisitSummary[testVisit][testDetector].setWcs(None)

outputs = self.task.run(
inputVisitSummary,
self.isolatedStarSources,
self.isolatedStarCatalogs,
instrumentName=self.instrumentName,
refEpoch=self.refEpoch,
refObjectLoader=self.refObjectLoader,
)

# Check that the fit WCS for the extension with input WCS=None returns
# finite sky values.
testWcs = outputs.outputWcss[self.testVisits[testVisit]][testDetector].getWcs()
testSky = testWcs.pixelToSky(0, 0)
self.assertTrue(testSky.isFinite())


def setup_module(module):
lsst.utils.tests.init()
Expand Down

0 comments on commit a107c60

Please sign in to comment.