Skip to content

Commit

Permalink
Add memory-efficient stack for consolidate tasks.
Browse files Browse the repository at this point in the history
Using this for MakeCcdVisitTableTask is trickier because we can't
ask an ExposureCatalog how many rows it has before loading it.  If
that's needed, we can extend this code to do it on another branch.
  • Loading branch information
TallJimbo committed Dec 13, 2024
1 parent 77c6147 commit b9a7f93
Showing 1 changed file with 95 additions and 3 deletions.
98 changes: 95 additions & 3 deletions python/lsst/pipe/tasks/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import numpy as np
import pandas as pd
import astropy.table
import astropy.utils.metadata

import lsst.geom
import lsst.pex.config as pexConfig
Expand Down Expand Up @@ -82,6 +83,95 @@ def flattenFilters(df, noDupCols=["coord_ra", "coord_dec"], camelCase=False, inp
return newDf


class TableVStack:
    """A helper class for stacking astropy tables without having them all in
    memory at once.

    Parameters
    ----------
    capacity : `int`
        Full size of the final table.

    Notes
    -----
    Unlike `astropy.table.vstack`, this class requires all tables to have the
    exact same columns (it's slightly more strict than even the
    ``join_type="exact"`` argument to `astropy.table.vstack`).
    """

    def __init__(self, capacity):
        # Row index at which the next table appended via `extend` will be
        # inserted.
        self.index = 0
        # Total number of rows the result table is preallocated to hold.
        self.capacity = capacity
        # The accumulated table; created lazily on the first `extend` call so
        # column types can be taken from the first input table.
        self.result = None

    @classmethod
    def from_handles(cls, handles):
        """Construct from an iterable of
        `lsst.daf.butler.DeferredDatasetHandle`.

        Parameters
        ----------
        handles : `~collections.abc.Iterable` [ \
                `lsst.daf.butler.DeferredDatasetHandle` ]
            Iterable of handles.  Must have a storage class that supports the
            "rowcount" component, which is all that will be fetched.

        Returns
        -------
        vstack : `TableVStack`
            An instance of this class, initialized with capacity equal to the
            sum of the rowcounts of all the given table handles.
        """
        # Only the row counts are fetched here; the full tables are read one
        # at a time later (see `vstack_handles`), keeping peak memory low.
        capacity = sum(handle.get("rowcount") for handle in handles)
        return cls(capacity=capacity)

    def extend(self, table):
        """Add a single table to the stack.

        Parameters
        ----------
        table : `astropy.table.Table`
            An astropy table instance.
        """
        if self.result is None:
            # First table: preallocate full-capacity columns with the same
            # types as this table's columns, then copy its rows in.
            self.result = astropy.table.Table()
            for name in table.colnames:
                column = table[name]
                column_cls = type(column)
                self.result[name] = column_cls.info.new_like([column], self.capacity, name=name)
                # BUG FIX: new_like only allocates a column of the requested
                # length; it does not copy data.  Without this assignment the
                # first table's rows are left uninitialized in the result.
                self.result[name][:len(table)] = column
            self.index = len(table)
            self.result.meta = table.meta.copy()
        else:
            next_index = self.index + len(table)
            # Fail with a clear message instead of the cryptic broadcast
            # error the clipped slice assignment below would otherwise raise.
            if next_index > self.capacity:
                raise ValueError(
                    f"Cannot insert table with {len(table)} rows at row {self.index}: "
                    f"only {self.capacity} rows were preallocated."
                )
            for name in table.colnames:
                self.result[name][self.index:next_index] = table[name]
            self.index = next_index
            # Merge metadata the same way `astropy.table.vstack` would.
            self.result.meta = astropy.utils.metadata.merge(self.result.meta, table.meta)

    @classmethod
    def vstack_handles(cls, handles):
        """Vertically stack tables represented by deferred dataset handles.

        Parameters
        ----------
        handles : `~collections.abc.Iterable` [ \
                `lsst.daf.butler.DeferredDatasetHandle` ]
            Iterable of handles.  Must have the "ArrowAstropy" storage class
            and identical columns.

        Returns
        -------
        table : `astropy.table.Table`
            Concatenated table with the same columns as each input table and
            the rows of all of them.
        """
        handles = tuple(handles)  # guard against single-pass iterators
        vstack = cls.from_handles(handles)
        for handle in handles:
            vstack.extend(handle.get())
        return vstack.result


class WriteObjectTableConnections(pipeBase.PipelineTaskConnections,
defaultTemplates={"coaddName": "deep"},
dimensions=("tract", "patch", "skymap")):
Expand Down Expand Up @@ -932,6 +1022,7 @@ class ConsolidateObjectTableConnections(pipeBase.PipelineTaskConnections,
storageClass="ArrowAstropy",
dimensions=("tract", "patch", "skymap"),
multiple=True,
deferLoad=True,
)
outputCatalog = connectionTypes.Output(
doc="Pre-tract horizontal concatenation of the input objectTables",
Expand Down Expand Up @@ -965,7 +1056,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
self.log.info("Concatenating %s per-patch Object Tables",
len(inputs["inputCatalogs"]))
table = astropy.table.vstack(inputs["inputCatalogs"], join_type="exact")
table = TableVStack.vstack_handles(inputs["inputCatalogs"])
butlerQC.put(pipeBase.Struct(outputCatalog=table), outputRefs)


Expand Down Expand Up @@ -1142,7 +1233,8 @@ class ConsolidateSourceTableConnections(pipeBase.PipelineTaskConnections,
name="{catalogType}sourceTable",
storageClass="ArrowAstropy",
dimensions=("instrument", "visit", "detector"),
multiple=True
multiple=True,
deferLoad=True,
)
outputCatalog = connectionTypes.Output(
doc="Per-visit concatenation of Source Table",
Expand Down Expand Up @@ -1175,7 +1267,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
self.log.info("Concatenating %s per-detector Source Tables",
len(inputs["inputCatalogs"]))
table = astropy.table.vstack(inputs["inputCatalogs"], join_type="exact")
table = TableVStack.vstack_handles(inputs["inputCatalogs"])
butlerQC.put(pipeBase.Struct(outputCatalog=table), outputRefs)


Expand Down

0 comments on commit b9a7f93

Please sign in to comment.