From 7fc0c7d23560c3cbb188d6c9e684d17268514389 Mon Sep 17 00:00:00 2001 From: Dan Taranu Date: Mon, 16 Sep 2024 10:20:20 -0700 Subject: [PATCH] Replace pandas.DataFrame usage with astropy.Table --- .../pipe/tasks/diff_matched_tract_catalog.py | 151 +++++++++++------- python/lsst/pipe/tasks/match_tract_catalog.py | 26 +-- .../match_tract_catalog_probabilistic.py | 9 +- tests/test_diff_matched_tract_catalog.py | 56 ++++--- 4 files changed, 145 insertions(+), 97 deletions(-) diff --git a/python/lsst/pipe/tasks/diff_matched_tract_catalog.py b/python/lsst/pipe/tasks/diff_matched_tract_catalog.py index daee70e99..9dba26ce1 100644 --- a/python/lsst/pipe/tasks/diff_matched_tract_catalog.py +++ b/python/lsst/pipe/tasks/diff_matched_tract_catalog.py @@ -37,6 +37,7 @@ from abc import ABCMeta, abstractmethod from astropy.stats import mad_std +import astropy.table import astropy.units as u from dataclasses import dataclass from decimal import Decimal @@ -48,6 +49,7 @@ from smatch.matcher import sphdist from types import SimpleNamespace from typing import Sequence +import warnings def is_sequence_set(x: Sequence): @@ -75,14 +77,14 @@ class DiffMatchedTractCatalogConnections( cat_ref = cT.Input( doc="Reference object catalog to match from", name="{name_input_cat_ref}", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("tract", "skymap"), deferLoad=True, ) cat_target = cT.Input( doc="Target object catalog to match", name="{name_input_cat_target}", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("tract", "skymap"), deferLoad=True, ) @@ -95,33 +97,33 @@ class DiffMatchedTractCatalogConnections( cat_match_ref = cT.Input( doc="Reference match catalog with indices of target matches", name="match_ref_{name_input_cat_ref}_{name_input_cat_target}", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("tract", "skymap"), deferLoad=True, ) cat_match_target = cT.Input( doc="Target match catalog with indices of references matches", name="match_target_{name_input_cat_ref}_{name_input_cat_target}", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("tract", "skymap"), deferLoad=True, ) columns_match_target = cT.Input( doc="Target match catalog columns", name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns", - storageClass="DataFrameIndex", + storageClass="ArrowColumnList", dimensions=("tract", "skymap"), ) cat_matched = cT.Output( doc="Catalog with reference and target columns for joined sources", name="matched_{name_input_cat_ref}_{name_input_cat_target}", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("tract", "skymap"), ) diff_matched = cT.Output( doc="Table with aggregated counts, difference and chi statistics", name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("tract", "skymap"), ) @@ -137,6 +139,8 @@ def __init__(self, *, config=None): dimensions=(), deferLoad=old.deferLoad, ) + if not (config.compute_stats and len(config.columns_flux) > 0): + del self.diff_matched class MatchedCatalogFluxesConfig(pexConfig.Config): @@ -685,10 +689,10 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): def run( self, - catalog_ref: pd.DataFrame, - catalog_target: pd.DataFrame, - catalog_match_ref: pd.DataFrame, - catalog_match_target: pd.DataFrame, + catalog_ref: pd.DataFrame | astropy.table.Table, + catalog_target: pd.DataFrame | astropy.table.Table, + catalog_match_ref: pd.DataFrame | astropy.table.Table, + catalog_match_target: pd.DataFrame | astropy.table.Table, wcs: afwGeom.SkyWcs = None, ) -> pipeBase.Struct: """Load matched reference and target (measured) catalogs, measure summary statistics, and output @@ -696,14 +700,14 @@ def run( Parameters ---------- - catalog_ref : `pandas.DataFrame` + catalog_ref : `pandas.DataFrame` | `astropy.table.Table` A reference catalog to diff objects/sources from. - catalog_target : `pandas.DataFrame` + catalog_target : `pandas.DataFrame` | `astropy.table.Table` A target catalog to diff reference objects/sources to. - catalog_match_ref : `pandas.DataFrame` + catalog_match_ref : `pandas.DataFrame` | `astropy.table.Table` A catalog with match indices of target sources and selection flags for each reference source. - catalog_match_target : `pandas.DataFrame` + catalog_match_target : `pandas.DataFrame` | `astropy.table.Table` A catalog with selection flags for each target source. wcs : `lsst.afw.image.SkyWcs` A coordinate system to convert catalog positions to sky coordinates, @@ -718,16 +722,33 @@ def run( # Would be nice if this could refer directly to ConfigClass config: DiffMatchedTractCatalogConfig = self.config - select_ref = catalog_match_ref['match_candidate'].values + is_ref_pd = isinstance(catalog_ref, pd.DataFrame) + is_target_pd = isinstance(catalog_target, pd.DataFrame) + is_match_ref_pd = isinstance(catalog_match_ref, pd.DataFrame) + is_match_target_pd = isinstance(catalog_match_target, pd.DataFrame) + if is_ref_pd: + catalog_ref = astropy.table.Table.from_pandas(catalog_ref) + if is_target_pd: + catalog_target = astropy.table.Table.from_pandas(catalog_target) + if is_match_ref_pd: + catalog_match_ref = astropy.table.Table.from_pandas(catalog_match_ref) + if is_match_target_pd: + catalog_match_target = astropy.table.Table.from_pandas(catalog_match_target) + # TODO: Remove pandas support in DM-46523 + if is_ref_pd or is_target_pd or is_match_ref_pd or is_match_target_pd: + warnings.warn("pandas usage in MatchProbabilisticTask is deprecated; it will be removed " + " in favour of astropy.table after release 28.0.0", category=FutureWarning) + + select_ref = catalog_match_ref['match_candidate'] # Add additional selection criteria for target sources beyond those for matching # (not recommended, but can be done anyway) - select_target = (catalog_match_target['match_candidate'].values + select_target = (catalog_match_target['match_candidate'] if 'match_candidate' in catalog_match_target.columns else np.ones(len(catalog_match_target), dtype=bool)) for column in config.columns_target_select_true: - select_target &= catalog_target[column].values + select_target &= catalog_target[column] for column in config.columns_target_select_false: - select_target &= ~catalog_target[column].values + select_target &= ~catalog_target[column] ref, target = config.coord_format.format_catalogs( catalog_ref=catalog_ref, catalog_target=catalog_target, @@ -739,9 +760,9 @@ def run( if config.include_unmatched: for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)): - cat_add['match_candidate'] = cat_match['match_candidate'].values + cat_add['match_candidate'] = cat_match['match_candidate'] - match_row = catalog_match_ref['match_row'].values + match_row = catalog_match_ref['match_row'] matched_ref = match_row >= 0 matched_row = match_row[matched_ref] matched_target = np.zeros(n_target, dtype=bool) @@ -761,48 +782,44 @@ def run( ) if config.coord_format.coords_spherical else np.hypot( target_match_c1 - target_ref_c1, target_match_c2 - target_ref_c2, ) + cat_target_matched = cat_target[matched_row] + # This will convert a masked array to an array filled with nans + # wherever there are bad values (otherwise sphdist can raise) + c1_err, c2_err = ( + np.ma.getdata(cat_target_matched[c_err]) for c_err in (coord1_target_err, coord2_target_err) + ) # Should probably explicitly add cosine terms if ref has errors too dist_err[matched_row] = sphdist( - target_match_c1, target_match_c2, - target_match_c1 + cat_target.iloc[matched_row][coord1_target_err].values, - target_match_c2 + cat_target.iloc[matched_row][coord2_target_err].values, - ) if config.coord_format.coords_spherical else np.hypot( - cat_target.iloc[matched_row][coord1_target_err].values, - cat_target.iloc[matched_row][coord2_target_err].values - ) + target_match_c1, target_match_c2, target_match_c1 + c1_err, target_match_c2 + c2_err + ) if config.coord_format.coords_spherical else np.hypot(c1_err, c2_err) cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err # Create a matched table, preserving the target catalog's named index (if it has one) - cat_left = cat_target.iloc[matched_row] - has_index_left = cat_left.index.name is not None - cat_right = cat_ref[matched_ref].reset_index() - cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns] - cat_matched = pd.concat(objs=(cat_left.reset_index(drop=not has_index_left), cat_right), axis=1) + cat_left = cat_target[matched_row] + cat_right = cat_ref[matched_ref] + cat_right.rename_columns( + list(cat_right.columns), + new_names=[f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns], + ) + cat_matched = astropy.table.hstack((cat_left, cat_right)) if config.include_unmatched: # Create an unmatched table with the same schema as the matched one # ... but only for objects with no matches (for completeness/purity) # and that were selected for matching (or inclusion via config) - cat_right = cat_ref[~matched_ref & select_ref].reset_index(drop=False) - cat_right.columns = (f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns) - match_row_target = catalog_match_target['match_row'].values - cat_left = cat_target[~(match_row_target >= 0) & select_target].reset_index( - drop=not has_index_left) + cat_right = astropy.table.Table( + cat_ref[~matched_ref & select_ref] + ) + cat_right.rename_columns( + cat_right.colnames, + [f"{config.column_matched_prefix_ref}{col}" for col in cat_right.colnames], + ) + match_row_target = catalog_match_target['match_row'] + cat_left = cat_target[~(match_row_target >= 0) & select_target] + # This may be slower than pandas but will, for example, create + # masked columns for booleans, which pandas does not support. # See https://github.com/pandas-dev/pandas/issues/46662 - # astropy masked columns would handle this much more gracefully - # Unfortunately, that would require storageClass migration - # So we use pandas "extended" nullable types for now - for cat_i in (cat_left, cat_right): - for colname in cat_i.columns: - column = cat_i[colname] - dtype = str(column.dtype) - if dtype == "bool": - cat_i[colname] = column.astype("boolean") - elif dtype.startswith("int"): - cat_i[colname] = column.astype(f"Int{dtype[3:]}") - elif dtype.startswith("uint"): - cat_i[colname] = column.astype(f"UInt{dtype[3:]}") - cat_unmatched = pd.concat(objs=(cat_left, cat_right)) + cat_unmatched = astropy.table.vstack([cat_left, cat_right]) for columns_convert_base, prefix in ( (config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref), @@ -812,8 +829,14 @@ def run( columns_convert = { f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items() } if prefix else columns_convert_base - for cat_convert in (cat_matched, cat_unmatched): - cat_convert.rename(columns=columns_convert, inplace=True) + to_convert = [cat_matched] + if config.include_unmatched: + to_convert.append(cat_unmatched) + for cat_convert in to_convert: + cat_convert.rename_columns( + tuple(columns_convert.keys()), + tuple(columns_convert.values()), + ) for column_flux in columns_convert.values(): cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux]) @@ -822,7 +845,8 @@ def run( n_bands = len(band_fluxes) # TODO: Deprecated by RFC-1017 and to be removed in DM-44988 - if self.config.compute_stats and (n_bands > 0): + do_stats = self.config.compute_stats and (n_bands > 0) + if do_stats: # Slightly smelly hack for when a column (like distance) is already relative to truth column_dummy = 'dummy' cat_ref[column_dummy] = np.zeros_like(ref.coord1) @@ -831,7 +855,7 @@ def run( # TODO: remove the assumption of a boolean column extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted) - extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut + extended_target = cat_target[config.column_target_extended] >= config.extendedness_cut # Define difference/chi columns and statistics thereof suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'} @@ -999,7 +1023,7 @@ def run( if n_match > 0: rows_matched = match_row_bin[match_good] - subset_target = cat_target.iloc[rows_matched] + subset_target = cat_target[rows_matched] if (is_extended is not None) and (idx_model == 0): right_type = extended_target[rows_matched] == is_extended n_total = len(right_type) @@ -1016,15 +1040,15 @@ def run( # compute stats for this bin, for all columns for column, (column_ref, column_target, column_err_target, skip_diff) \ in columns_target.items(): - values_ref = cat_ref[column_ref][match_good].values + values_ref = cat_ref[column_ref][match_good] errors_target = ( - subset_target[column_err_target].values + subset_target[column_err_target] if column_err_target is not None else None ) compute_stats( values_ref, - subset_target[column_target].values, + subset_target[column_target], errors_target, row, stats, @@ -1066,7 +1090,10 @@ def run( mag_ref_first = mag_ref if config.include_unmatched: - cat_matched = pd.concat((cat_matched, cat_unmatched)) + # This is probably less efficient than just doing an outer join originally; worth checking + cat_matched = astropy.table.vstack([cat_matched, cat_unmatched]) - retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data)) + retStruct = pipeBase.Struct(cat_matched=cat_matched) + if do_stats: + retStruct.diff_matched = astropy.table.Table(data) return retStruct diff --git a/python/lsst/pipe/tasks/match_tract_catalog.py b/python/lsst/pipe/tasks/match_tract_catalog.py index 249ba1b10..51aba146e 100644 --- a/python/lsst/pipe/tasks/match_tract_catalog.py +++ b/python/lsst/pipe/tasks/match_tract_catalog.py @@ -32,6 +32,7 @@ from abc import ABC, abstractmethod +import astropy.table import pandas as pd from typing import Tuple, Set @@ -50,14 +51,14 @@ class MatchTractCatalogConnections( cat_ref = cT.Input( doc="Reference object catalog to match from", name="{name_input_cat_ref}", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("tract", "skymap"), deferLoad=True, ) cat_target = cT.Input( doc="Target object catalog to match", name="{name_input_cat_target}", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("tract", "skymap"), deferLoad=True, ) @@ -67,17 +68,16 @@ class MatchTractCatalogConnections( storageClass="SkyMap", dimensions=("skymap",), ) - # TODO: Change outputs to ArrowAstropy in DM-44159 cat_output_ref = cT.Output( doc="Reference matched catalog with indices of target matches", name="match_ref_{name_input_cat_ref}_{name_input_cat_target}", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("tract", "skymap"), ) cat_output_target = cT.Output( doc="Target matched catalog with indices of reference matches", name="match_target_{name_input_cat_ref}_{name_input_cat_target}", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("tract", "skymap"), ) @@ -128,17 +128,17 @@ def __init__(self, **kwargs): @abstractmethod def run( self, - catalog_ref: pd.DataFrame, - catalog_target: pd.DataFrame, + catalog_ref: pd.DataFrame | astropy.table.Table, + catalog_target: pd.DataFrame | astropy.table.Table, wcs: afwGeom.SkyWcs = None, ) -> pipeBase.Struct: """Match sources in a reference tract catalog with a target catalog. Parameters ---------- - catalog_ref : `pandas.DataFrame` + catalog_ref : `pandas.DataFrame` | `astropy.table.Table` A reference catalog to match objects/sources from. - catalog_target : `pandas.DataFrame` + catalog_target : `pandas.DataFrame` | `astropy.table.Table` A target catalog to match reference objects/sources to. wcs : `lsst.afw.image.SkyWcs` A coordinate system to convert catalog positions to sky coordinates. @@ -211,17 +211,17 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): def run( self, - catalog_ref: pd.DataFrame, - catalog_target: pd.DataFrame, + catalog_ref: pd.DataFrame | astropy.table.Table, + catalog_target: pd.DataFrame | astropy.table.Table, wcs: afwGeom.SkyWcs = None, ) -> pipeBase.Struct: """Match sources in a reference tract catalog with a target catalog. Parameters ---------- - catalog_ref : `pandas.DataFrame` + catalog_ref : `pandas.DataFrame` | `astropy.table.Table` A reference catalog to match objects/sources from. - catalog_target : `pandas.DataFrame` + catalog_target : `pandas.DataFrame` | `astropy.table.Table` A target catalog to match reference objects/sources to. wcs : `lsst.afw.image.SkyWcs` A coordinate system to convert catalog positions to sky coordinates, diff --git a/python/lsst/pipe/tasks/match_tract_catalog_probabilistic.py b/python/lsst/pipe/tasks/match_tract_catalog_probabilistic.py index 525276536..bb080f563 100644 --- a/python/lsst/pipe/tasks/match_tract_catalog_probabilistic.py +++ b/python/lsst/pipe/tasks/match_tract_catalog_probabilistic.py @@ -30,6 +30,7 @@ from .match_tract_catalog import MatchTractCatalogSubConfig, MatchTractCatalogSubTask +import astropy.table import pandas as pd from typing import Set @@ -65,17 +66,17 @@ def __init__(self, **kwargs): def run( self, - catalog_ref: pd.DataFrame, - catalog_target: pd.DataFrame, + catalog_ref: pd.DataFrame | astropy.table.Table, + catalog_target: pd.DataFrame | astropy.table.Table, wcs: afwGeom.SkyWcs = None, ) -> pipeBase.Struct: """Match sources in a reference tract catalog with a target catalog. Parameters ---------- - catalog_ref : `pandas.DataFrame` + catalog_ref : `pandas.DataFrame` | `astropy.table.Table` A reference catalog to match objects/sources from. - catalog_target : `pandas.DataFrame` + catalog_target : `pandas.DataFrame` | `astropy.table.Table` A target catalog to match reference objects/sources to. wcs : `lsst.afw.image.SkyWcs` A coordinate system to convert catalog positions to sky coordinates. diff --git a/tests/test_diff_matched_tract_catalog.py b/tests/test_diff_matched_tract_catalog.py index 432f1b919..f985bb5f3 100644 --- a/tests/test_diff_matched_tract_catalog.py +++ b/tests/test_diff_matched_tract_catalog.py @@ -24,15 +24,15 @@ import lsst.utils.tests import lsst.afw.geom as afwGeom -import pytest from lsst.meas.astrom import ConvertCatalogCoordinatesConfig from lsst.pipe.tasks.diff_matched_tract_catalog import ( DiffMatchedTractCatalogConfig, DiffMatchedTractCatalogTask, MatchedCatalogFluxesConfig, ) +from astropy.table import Table import numpy as np import os -import pandas as pd +import pytest def _error_format(column): @@ -94,7 +94,7 @@ def setUp(self): columns_flux[1]: fluxes[1][idx_ref], DiffMatchedTractCatalogConfig.column_ref_extended.default: extended_ref, } - self.catalog_ref = pd.DataFrame(data=data_ref) + self.catalog_ref = Table(data=data_ref) data_target = { column_ra_target: ra + eps_coord, @@ -109,16 +109,16 @@ def setUp(self): DiffMatchedTractCatalogConfig.columns_target_select_false.default[0]: ~flags, DiffMatchedTractCatalogConfig.column_target_extended.default: extended_target, } - self.catalog_target = pd.DataFrame(data=data_target) + self.catalog_target = Table(data=data_target) # Make the last two rows unmatched (we set eps_coord very large) match_row = np.arange(len(ra))[::-1] - n_unmatched - self.catalog_match_ref = pd.DataFrame(data={ + self.catalog_match_ref = Table(data={ 'match_candidate': flags, 'match_row': match_row, }) - self.catalog_match_target = pd.DataFrame(data={ + self.catalog_match_target = Table(data={ 'match_candidate': flags, 'match_row': match_row, }) @@ -159,30 +159,50 @@ def test_DiffMatchedTractCatalogTask(self): # These tables will have columns added to them in run columns_ref, columns_target = (list(x.columns) for x in (self.catalog_ref, self.catalog_target)) task = DiffMatchedTractCatalogTask(config=self.config_stats) + result = task.run( + catalog_ref=self.catalog_ref, + catalog_target=self.catalog_target, + catalog_match_ref=self.catalog_match_ref, + catalog_match_target=self.catalog_match_target, + wcs=self.wcs, + ) + # TODO: Remove pandas support in DM-46523 + with pytest.warns(FutureWarning): + result_pd = task.run( + catalog_ref=self.catalog_ref.to_pandas(), + catalog_target=self.catalog_target.to_pandas(), + catalog_match_ref=self.catalog_match_ref.to_pandas(), + catalog_match_target=self.catalog_match_target.to_pandas(), + wcs=self.wcs, + ) + self.assertListEqual(list(result.cat_matched.columns), list(result_pd.cat_matched.columns)) + for column in result.cat_matched.columns: + self.assertListEqual(list(result.cat_matched[column]), list(result_pd.cat_matched[column])) + # TODO: Remove diff_matched support in DM-44988 with pytest.warns(FutureWarning): task.config.compute_stats = True - result = task.run( + result_stats = task.run( catalog_ref=self.catalog_ref, catalog_target=self.catalog_target, catalog_match_ref=self.catalog_match_ref, catalog_match_target=self.catalog_match_target, wcs=self.wcs, ) + self.assertGreater(len(result_stats.diff_matched), 0) + row = np.array([float(x) for x in result_stats.diff_matched[0].values()]) + # Run to re-save reference data. Will be loaded after this test completes. + resave = False + if resave: + np.savetxt(filename_diff_matched, row) + + self.assertEqual(len(row), len(self.diff_matched)) + self.assertFloatsAlmostEqual(row, self.diff_matched, atol=1e-8, rtol=1e-8) + columns_result = list(result.cat_matched.columns) columns_expect = list(columns_target) + ["match_distance", "match_distanceErr"] prefix = DiffMatchedTractCatalogConfig.column_matched_prefix_ref.default - columns_expect.append(f'{prefix}index') columns_expect.extend((f'{prefix}{col}' for col in columns_ref)) - self.assertEqual(columns_expect, columns_result) - - row = result.diff_matched.iloc[0].values.astype(float) - # Run to re-save reference data. Will be loaded after this test completes. - resave = False - if resave: - np.savetxt(filename_diff_matched, row) - - self.assertEqual(len(row), len(self.diff_matched)) - self.assertFloatsAlmostEqual(row, self.diff_matched, atol=1e-8, rtol=1e-8) + self.assertListEqual(columns_expect, columns_result) def test_spherical(self): task = DiffMatchedTractCatalogTask(config=self.config)