From 6244a5a1ad055793bb4d515dcff6d7e7cbe97a76 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Wed, 6 Nov 2024 15:28:09 -0500 Subject: [PATCH 1/2] ingest somacore classes --- .pre-commit-config.yaml | 3 +- apis/python/setup.py | 4 +- apis/python/src/tiledbsoma/_experiment.py | 8 +- apis/python/src/tiledbsoma/_indexer.py | 14 +- apis/python/src/tiledbsoma/_query.py | 850 +++++++++++++++++++++ apis/python/tests/test_experiment_query.py | 105 +-- 6 files changed, 925 insertions(+), 59 deletions(-) create mode 100644 apis/python/src/tiledbsoma/_query.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 54cbc5314b..2807cf368a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,8 @@ repos: # Pandas 2.x types (e.g. `pd.Series[Any]`). See `_types.py` or https://github.com/single-cell-data/TileDB-SOMA/issues/2839 # for more info. - "pandas-stubs>=2" - - "somacore==1.0.23" + # Temporary, for PR: see https://github.com/single-cell-data/SOMA/pull/244 + - "git+https://github.com/single-cell-data/soma@9e81f07" - types-setuptools args: ["--config-file=apis/python/pyproject.toml", "apis/python/src", "apis/python/devtools"] pass_filenames: false diff --git a/apis/python/setup.py b/apis/python/setup.py index 3742b7a6cd..e014d0334d 100644 --- a/apis/python/setup.py +++ b/apis/python/setup.py @@ -336,8 +336,8 @@ def run(self): "pyarrow", "scanpy>=1.9.2", "scipy", - # Note: the somacore version is in .pre-commit-config.yaml too - "somacore==1.0.23", + # Temporary, for PR: see https://github.com/single-cell-data/SOMA/pull/244 + "somacore @ git+https://github.com/single-cell-data/soma@rw/abcs", "typing-extensions", # Note "-" even though `import typing_extensions` ], extras_require={ diff --git a/apis/python/src/tiledbsoma/_experiment.py b/apis/python/src/tiledbsoma/_experiment.py index 5a57662761..dcd297d29e 100644 --- a/apis/python/src/tiledbsoma/_experiment.py +++ b/apis/python/src/tiledbsoma/_experiment.py @@ -9,13 +9,13 @@ from typing import Optional from somacore import experiment, query -from typing_extensions import Self from . import _tdb_handles from ._collection import Collection, CollectionBase from ._dataframe import DataFrame from ._indexer import IntIndexer from ._measurement import Measurement +from ._query import ExperimentAxisQuery from ._scene import Scene from ._soma_object import AnySOMAObject @@ -83,13 +83,11 @@ def axis_query( # type: ignore *, obs_query: Optional[query.AxisQuery] = None, var_query: Optional[query.AxisQuery] = None, - ) -> query.ExperimentAxisQuery[Self]: # type: ignore + ) -> ExperimentAxisQuery: """Creates an axis query over this experiment. Lifecycle: Maturing. """ - # mypy doesn't quite understand descriptors so it issues a spurious - # error here. - return query.ExperimentAxisQuery( # type: ignore + return ExperimentAxisQuery( self, measurement_name, obs_query=obs_query or query.AxisQuery(), diff --git a/apis/python/src/tiledbsoma/_indexer.py b/apis/python/src/tiledbsoma/_indexer.py index 4f39b0b042..2dc92f37c2 100644 --- a/apis/python/src/tiledbsoma/_indexer.py +++ b/apis/python/src/tiledbsoma/_indexer.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import List, Optional, Union import numpy as np import numpy.typing as npt @@ -11,9 +11,7 @@ from tiledbsoma import pytiledbsoma as clib from ._types import PDSeries - -if TYPE_CHECKING: - from .options import SOMATileDBContext +from .options import SOMATileDBContext IndexerDataType = Union[ npt.NDArray[np.int64], @@ -27,7 +25,7 @@ def tiledbsoma_build_index( - data: IndexerDataType, *, context: Optional["SOMATileDBContext"] = None + data: IndexerDataType, *, context: Optional[SOMATileDBContext] = None ) -> IndexLike: """Initialize re-indexer for provided indices (deprecated). @@ -54,7 +52,7 @@ class IntIndexer: """ def __init__( - self, data: IndexerDataType, *, context: Optional["SOMATileDBContext"] = None + self, data: IndexerDataType, *, context: Optional[SOMATileDBContext] = None ): """Initialize re-indexer for provided indices. @@ -73,7 +71,7 @@ def __init__( ) self._reindexer.map_locations(data) - def get_indexer(self, target: IndexerDataType) -> Any: + def get_indexer(self, target: IndexerDataType) -> npt.NDArray[np.intp]: """Compute underlying indices of index for target data. Compatible with Pandas' Index.get_indexer method. @@ -81,7 +79,7 @@ def get_indexer(self, target: IndexerDataType) -> Any: Args: target: Data to return re-index data for. """ - return ( + return ( # type: ignore[no-any-return] self._reindexer.get_indexer_pyarrow(target) if isinstance(target, (pa.Array, pa.ChunkedArray)) else self._reindexer.get_indexer_general(target) diff --git a/apis/python/src/tiledbsoma/_query.py b/apis/python/src/tiledbsoma/_query.py new file mode 100644 index 0000000000..888fcaf1ab --- /dev/null +++ b/apis/python/src/tiledbsoma/_query.py @@ -0,0 +1,850 @@ +# Copyright (c) 2021-2023 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2023 TileDB, Inc. +# +# Licensed under the MIT License. + +"""Implementation of a SOMA Experiment. +""" +import enum +from concurrent.futures import ThreadPoolExecutor +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Literal, + Mapping, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + cast, + overload, +) + +import attrs +import numpy as np +import numpy.typing as npt +import pandas as pd +import pyarrow as pa +import pyarrow.compute as pacomp +import scipy.sparse as sp +from anndata import AnnData +from somacore import ( + AxisQuery, + DataFrame, + ReadIter, + SparseRead, + query, +) +from somacore.data import _RO_AUTO +from somacore.options import ( + BatchSize, + PlatformConfig, + ReadPartitions, + ResultOrder, + ResultOrderStr, +) +from somacore.query import _fast_csr +from somacore.query.query import ( + AxisColumnNames, + Numpyable, +) +from somacore.query.types import IndexFactory, IndexLike +from typing_extensions import Self + +if TYPE_CHECKING: + from ._experiment import Experiment +from ._measurement import Measurement +from ._sparse_nd_array import SparseNDArray + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + + +class _HasObsVar(Protocol[_T_co]): + """Something which has an ``obs`` and ``var`` field. + + Used to give nicer type inference in :meth:`Axis.getattr_from`. + """ + + @property + def obs(self) -> _T_co: ... + + @property + def var(self) -> _T_co: ... + + +class Axis(enum.Enum): + OBS = "obs" + VAR = "var" + + @property + def value(self) -> Literal["obs", "var"]: + return super().value # type: ignore[no-any-return] + + @overload + def getattr_from(self, __source: _HasObsVar[_T]) -> _T: ... + + @overload + def getattr_from( + self, __source: Any, *, pre: Literal[""], suf: Literal[""] + ) -> object: ... + + @overload + def getattr_from( + self, __source: Any, *, pre: str = ..., suf: str = ... + ) -> object: ... + + def getattr_from(self, __source: Any, *, pre: str = "", suf: str = "") -> object: + """Equivalent to ``something.
``."""
+        return getattr(__source, pre + self.value + suf)
+
+    def getitem_from(
+        self, __source: Mapping[str, "_T"], *, pre: str = "", suf: str = ""
+    ) -> _T:
+        """Equivalent to ``something[pre + "obs"/"var" + suf]``."""
+        return __source[pre + self.value + suf]
+
+
+@attrs.define
+class AxisIndexer(query.AxisIndexer):
+    """
+    Given a query, provides index-building services for obs/var axis.
+
+    Lifecycle: maturing
+    """
+
+    query: "ExperimentAxisQuery"
+    _index_factory: IndexFactory
+    _cached_obs: Optional[IndexLike] = None
+    _cached_var: Optional[IndexLike] = None
+
+    @property
+    def _obs_index(self) -> IndexLike:
+        """Private. Return an index for the ``obs`` axis."""
+        if self._cached_obs is None:
+            self._cached_obs = self._index_factory(self.query.obs_joinids().to_numpy())
+        return self._cached_obs
+
+    @property
+    def _var_index(self) -> IndexLike:
+        """Private. Return an index for the ``var`` axis."""
+        if self._cached_var is None:
+            self._cached_var = self._index_factory(self.query.var_joinids().to_numpy())
+        return self._cached_var
+
+    def by_obs(self, coords: Numpyable) -> npt.NDArray[np.intp]:
+        """Reindex the coords (soma_joinids) over the ``obs`` axis."""
+        return self._obs_index.get_indexer(_to_numpy(coords))
+
+    def by_var(self, coords: Numpyable) -> npt.NDArray[np.intp]:
+        """Reindex for the coords (soma_joinids) over the ``var`` axis."""
+        return self._var_index.get_indexer(_to_numpy(coords))
+
+
+def _to_numpy(it: Numpyable) -> npt.NDArray[np.int64]:
+    if isinstance(it, np.ndarray):
+        return it
+    return it.to_numpy()  # type: ignore[no-any-return]
+
+
+@attrs.define(frozen=True)
+class AxisQueryResult:
+    """The result of running :meth:`ExperimentAxisQuery.read`. Private."""
+
+    obs: pd.DataFrame
+    """Experiment.obs query slice, as a pandas DataFrame"""
+    var: pd.DataFrame
+    """Experiment.ms[...].var query slice, as a pandas DataFrame"""
+    X: sp.csr_matrix
+    """Experiment.ms[...].X[...] query slice, as a SciPy sparse.csr_matrix """
+    X_layers: Dict[str, sp.csr_matrix] = attrs.field(factory=dict)
+    """Any additional X layers requested, as SciPy sparse.csr_matrix(s)"""
+    obsm: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
+    """Experiment.obsm query slice, as a numpy ndarray"""
+    obsp: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
+    """Experiment.obsp query slice, as a numpy ndarray"""
+    varm: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
+    """Experiment.varm query slice, as a numpy ndarray"""
+    varp: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
+    """Experiment.varp query slice, as a numpy ndarray"""
+
+    def to_anndata(self) -> AnnData:
+        return AnnData(
+            X=self.X,
+            obs=self.obs,
+            var=self.var,
+            obsm=(self.obsm or None),
+            obsp=(self.obsp or None),
+            varm=(self.varm or None),
+            varp=(self.varp or None),
+            layers=(self.X_layers or None),
+        )
+
+
+class ExperimentAxisQuery:
+    """Axis-based query against a SOMA Experiment.
+
+    ExperimentAxisQuery allows easy selection and extraction of data from a
+    single :class:`Measurement` in an :class:`Experiment`, by obs/var (axis) coordinates
+    and/or value filter.
+
+    The primary use for this class is slicing :class:`Experiment` ``X`` layers by obs or
+    var value and/or coordinates. Slicing on :class:`SparseNDArray` ``X`` matrices is
+    supported; :class:`DenseNDArray` is not supported at this time.
+
+    IMPORTANT: this class is not thread-safe.
+
+    IMPORTANT: this query class assumes it can store the full result of both
+    axis dataframe queries in memory, and only provides incremental access to
+    the underlying X NDArray. API features such as ``n_obs`` and ``n_vars``
+    codify this in the API.
+
+    IMPORTANT: you must call ``close()`` on any instance of this class to
+    release underlying resources. The ExperimentAxisQuery is a context manager,
+    and it is recommended that you use the following pattern to make this easy
+    and safe::
+
+        with ExperimentAxisQuery(...) as query:
+            ...
+
+    This base query implementation is designed to work against any SOMA
+    implementation that fulfills the basic APIs. A SOMA implementation may
+    include a custom query implementation optimized for its own use.
+
+    Lifecycle: maturing
+    """
+
+    def __init__(
+        self,
+        experiment: "Experiment",
+        measurement_name: str,
+        *,
+        obs_query: AxisQuery = AxisQuery(),
+        var_query: AxisQuery = AxisQuery(),
+        index_factory: IndexFactory = pd.Index,
+    ):
+        if measurement_name not in experiment.ms:
+            raise ValueError("Measurement does not exist in the experiment")
+
+        # Users often like to pass `foo=None` and we should let them
+        obs_query = obs_query or AxisQuery()
+        var_query = var_query or AxisQuery()
+
+        self.experiment = experiment
+        self.measurement_name = measurement_name
+
+        self._matrix_axis_query = MatrixAxisQuery(obs=obs_query, var=var_query)
+        self._joinids = JoinIDCache(self)
+        self._indexer = AxisIndexer(
+            self,
+            index_factory=index_factory,
+        )
+        self._index_factory = index_factory
+        self._threadpool_: Optional[ThreadPoolExecutor] = None
+
+    def obs(
+        self,
+        *,
+        column_names: Optional[Sequence[str]] = None,
+        batch_size: BatchSize = BatchSize(),
+        partitions: Optional[ReadPartitions] = None,
+        result_order: ResultOrderStr = _RO_AUTO,
+        platform_config: Optional[PlatformConfig] = None,
+    ) -> ReadIter[pa.Table]:
+        """Returns ``obs`` as an `Arrow table
+        `_
+        iterator.
+
+        Lifecycle: maturing
+        """
+        obs_query = self._matrix_axis_query.obs
+        return self._obs_df.read(
+            obs_query.coords,
+            value_filter=obs_query.value_filter,
+            column_names=column_names,
+            batch_size=batch_size,
+            partitions=partitions,
+            result_order=result_order,
+            platform_config=platform_config,
+        )
+
+    def var(
+        self,
+        *,
+        column_names: Optional[Sequence[str]] = None,
+        batch_size: BatchSize = BatchSize(),
+        partitions: Optional[ReadPartitions] = None,
+        result_order: ResultOrderStr = _RO_AUTO,
+        platform_config: Optional[PlatformConfig] = None,
+    ) -> ReadIter[pa.Table]:
+        """Returns ``var`` as an `Arrow table
+        `_
+        iterator.
+
+        Lifecycle: maturing
+        """
+        var_query = self._matrix_axis_query.var
+        return self._var_df.read(
+            var_query.coords,
+            value_filter=var_query.value_filter,
+            column_names=column_names,
+            batch_size=batch_size,
+            partitions=partitions,
+            result_order=result_order,
+            platform_config=platform_config,
+        )
+
+    def obs_joinids(self) -> pa.IntegerArray:
+        """Returns ``obs`` ``soma_joinids`` as an Arrow array.
+
+        Lifecycle: maturing
+        """
+        return self._joinids.obs
+
+    def var_joinids(self) -> pa.IntegerArray:
+        """Returns ``var`` ``soma_joinids`` as an Arrow array.
+
+        Lifecycle: maturing
+        """
+        return self._joinids.var
+
+    @property
+    def n_obs(self) -> int:
+        """The number of ``obs`` axis query results.
+
+        Lifecycle: maturing
+        """
+        return len(self.obs_joinids())
+
+    @property
+    def n_vars(self) -> int:
+        """The number of ``var`` axis query results.
+
+        Lifecycle: maturing
+        """
+        return len(self.var_joinids())
+
+    @property
+    def indexer(self) -> AxisIndexer:
+        """A ``soma_joinid`` indexer for both ``obs`` and ``var`` axes.
+
+        Lifecycle: maturing
+        """
+        return self._indexer
+
+    def X(
+        self,
+        layer_name: str,
+        *,
+        batch_size: BatchSize = BatchSize(),
+        partitions: Optional[ReadPartitions] = None,
+        result_order: ResultOrderStr = _RO_AUTO,
+        platform_config: Optional[PlatformConfig] = None,
+    ) -> SparseRead:
+        """Returns an ``X`` layer as a sparse read.
+
+        Args:
+            layer_name: The X layer name to return.
+            batch_size: The size of batches that should be returned from a read.
+                See :class:`BatchSize` for details.
+            partitions: Specifies that this is part of a partitioned read,
+                and which partition to include, if present.
+            result_order: the order to return results, specified as a
+                :class:`~ResultOrder` or its string value.
+
+        Lifecycle: maturing
+        """
+        try:
+            x_layer = self._ms.X[layer_name]
+        except KeyError as ke:
+            raise KeyError(f"{layer_name} is not present in X") from ke
+        if not isinstance(x_layer, SparseNDArray):
+            raise TypeError("X layers may only be sparse arrays")
+
+        self._joinids.preload(self._threadpool)
+        return x_layer.read(
+            (self._joinids.obs, self._joinids.var),
+            batch_size=batch_size,
+            partitions=partitions,
+            result_order=result_order,
+            platform_config=platform_config,
+        )
+
+    def obsp(self, layer: str) -> SparseRead:
+        """Returns an ``obsp`` layer as a sparse read.
+
+        Lifecycle: maturing
+        """
+        return self._axisp_inner(Axis.OBS, layer)
+
+    def varp(self, layer: str) -> SparseRead:
+        """Returns a ``varp`` layer as a sparse read.
+
+        Lifecycle: maturing
+        """
+        return self._axisp_inner(Axis.VAR, layer)
+
+    def obsm(self, layer: str) -> SparseRead:
+        """Returns an ``obsm`` layer as a sparse read.
+        Lifecycle: maturing
+        """
+        return self._axism_inner(Axis.OBS, layer)
+
+    def varm(self, layer: str) -> SparseRead:
+        """Returns a ``varm`` layer as a sparse read.
+        Lifecycle: maturing
+        """
+        return self._axism_inner(Axis.VAR, layer)
+
+    def obs_scene_ids(self) -> pa.Array:
+        """Returns a pyarrow array with scene ids that contain obs from this
+        query.
+
+        Lifecycle: experimental
+        """
+        try:
+            obs_scene = self.experiment.obs_spatial_presence
+        except KeyError as ke:
+            raise KeyError("Missing obs_scene") from ke
+        if not isinstance(obs_scene, DataFrame):
+            raise TypeError("obs_scene must be a dataframe.")
+
+        full_table = obs_scene.read(
+            coords=((Axis.OBS.getattr_from(self._joinids), slice(None))),
+            result_order=ResultOrder.COLUMN_MAJOR,
+            value_filter="data != 0",
+        ).concat()
+
+        return pacomp.unique(full_table["scene_id"])
+
+    def var_scene_ids(self) -> pa.Array:
+        """Return a pyarrow array with scene ids that contain var from this
+        query.
+
+        Lifecycle: experimental
+        """
+        try:
+            var_scene = self._ms.var_spatial_presence
+        except KeyError as ke:
+            raise KeyError("Missing var_scene") from ke
+        if not isinstance(var_scene, DataFrame):
+            raise TypeError("var_scene must be a dataframe.")
+
+        full_table = var_scene.read(
+            coords=((Axis.OBS.getattr_from(self._joinids), slice(None))),
+            result_order=ResultOrder.COLUMN_MAJOR,
+            value_filter="data != 0",
+        ).concat()
+
+        return pacomp.unique(full_table["scene_id"])
+
+    def to_anndata(
+        self,
+        X_name: str,
+        *,
+        column_names: Optional[AxisColumnNames] = None,
+        X_layers: Sequence[str] = (),
+        obsm_layers: Sequence[str] = (),
+        obsp_layers: Sequence[str] = (),
+        varm_layers: Sequence[str] = (),
+        varp_layers: Sequence[str] = (),
+        drop_levels: bool = False,
+    ) -> AnnData:
+        ad = self._read(
+            X_name,
+            column_names=column_names or AxisColumnNames(obs=None, var=None),
+            X_layers=X_layers,
+            obsm_layers=obsm_layers,
+            obsp_layers=obsp_layers,
+            varm_layers=varm_layers,
+            varp_layers=varp_layers,
+        ).to_anndata()
+
+        # Drop unused categories on axis dataframes if requested
+        if drop_levels:
+            for name in ad.obs:
+                if ad.obs[name].dtype.name == "category":
+                    ad.obs[name] = ad.obs[name].cat.remove_unused_categories()
+            for name in ad.var:
+                if ad.var[name].dtype.name == "category":
+                    ad.var[name] = ad.var[name].cat.remove_unused_categories()
+
+        return ad
+
+    # Context management
+
+    def close(self) -> None:
+        """Releases resources associated with this query.
+
+        This method must be idempotent.
+
+        Lifecycle: maturing
+        """
+        # Because this may be called during ``__del__`` when we might be getting
+        # disassembled, sometimes ``_threadpool_`` is simply missing.
+        # Only try to shut it down if it still exists.
+        pool = getattr(self, "_threadpool_", None)
+        if pool is None:
+            return
+        pool.shutdown()
+        self._threadpool_ = None
+
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(self, *_: Any) -> None:
+        self.close()
+
+    def __del__(self) -> None:
+        """Ensure that we're closed when our last ref disappears."""
+        self.close()
+        # If any superclass in our MRO has a __del__, call it.
+        sdel = getattr(super(), "__del__", lambda: None)
+        sdel()
+
+    # Internals
+
+    def _read(
+        self,
+        X_name: str,
+        *,
+        column_names: AxisColumnNames,
+        X_layers: Sequence[str],
+        obsm_layers: Sequence[str] = (),
+        obsp_layers: Sequence[str] = (),
+        varm_layers: Sequence[str] = (),
+        varp_layers: Sequence[str] = (),
+    ) -> AxisQueryResult:
+        """Reads the entire query result in memory.
+
+        This is a low-level routine intended to be used by loaders for other
+        in-core formats, such as AnnData, which can be created from the
+        resulting objects.
+
+        Args:
+            X_name: The X layer to read and return in the ``X`` slot.
+            column_names: The columns in the ``var`` and ``obs`` dataframes
+                to read.
+            X_layers: Additional X layers to read and return
+                in the ``layers`` slot.
+            obsm_layers:
+                Additional obsm layers to read and return in the obsm slot.
+            obsp_layers:
+                Additional obsp layers to read and return in the obsp slot.
+            varm_layers:
+                Additional varm layers to read and return in the varm slot.
+            varp_layers:
+                Additional varp layers to read and return in the varp slot.
+        """
+        x_collection = self._ms.X
+        all_x_names = [X_name] + list(X_layers)
+        all_x_arrays: Dict[str, SparseNDArray] = {}
+        for _xname in all_x_names:
+            if not isinstance(_xname, str) or not _xname:
+                raise ValueError("X layer names must be specified as a string.")
+            if _xname not in x_collection:
+                raise ValueError("Unknown X layer name")
+            x_array = x_collection[_xname]
+            if not isinstance(x_array, SparseNDArray):
+                raise NotImplementedError("Dense array unsupported")
+            all_x_arrays[_xname] = x_array
+
+        def _read_axis_mappings(
+            fn: Callable[[Axis, str], npt.NDArray[Any]],
+            axis: Axis,
+            keys: Sequence[str],
+        ) -> Dict[str, npt.NDArray[Any]]:
+            return {key: fn(axis, key) for key in keys}
+
+        obsm_ft = self._threadpool.submit(
+            _read_axis_mappings, self._axism_inner_ndarray, Axis.OBS, obsm_layers
+        )
+        obsp_ft = self._threadpool.submit(
+            _read_axis_mappings, self._axisp_inner_ndarray, Axis.OBS, obsp_layers
+        )
+        varm_ft = self._threadpool.submit(
+            _read_axis_mappings, self._axism_inner_ndarray, Axis.VAR, varm_layers
+        )
+        varp_ft = self._threadpool.submit(
+            _read_axis_mappings, self._axisp_inner_ndarray, Axis.VAR, varp_layers
+        )
+
+        obs_table, var_table = self._read_both_axes(column_names)
+
+        obs_joinids = self.obs_joinids()
+        var_joinids = self.var_joinids()
+
+        x_matrices = {
+            _xname: (
+                _fast_csr.read_csr(
+                    layer,
+                    obs_joinids,
+                    var_joinids,
+                    index_factory=self._index_factory,
+                ).to_scipy()
+            )
+            for _xname, layer in all_x_arrays.items()
+        }
+
+        x = x_matrices.pop(X_name)
+
+        obs = obs_table.to_pandas()
+        obs.index = obs.index.astype(str)
+
+        var = var_table.to_pandas()
+        var.index = var.index.astype(str)
+
+        return AxisQueryResult(
+            obs=obs,
+            var=var,
+            X=x,
+            obsm=obsm_ft.result(),
+            obsp=obsp_ft.result(),
+            varm=varm_ft.result(),
+            varp=varp_ft.result(),
+            X_layers=x_matrices,
+        )
+
+    def _read_both_axes(
+        self,
+        column_names: AxisColumnNames,
+    ) -> Tuple[pa.Table, pa.Table]:
+        """Reads both axes in their entirety, ensuring soma_joinid is retained."""
+        obs_ft = self._threadpool.submit(
+            self._read_axis_dataframe,
+            Axis.OBS,
+            column_names,
+        )
+        var_ft = self._threadpool.submit(
+            self._read_axis_dataframe,
+            Axis.VAR,
+            column_names,
+        )
+        return obs_ft.result(), var_ft.result()
+
+    def _read_axis_dataframe(
+        self,
+        axis: Axis,
+        axis_column_names: AxisColumnNames,
+    ) -> pa.Table:
+        """Reads the specified axis. Will cache join IDs if not present."""
+        column_names = axis_column_names.get(axis.value)
+
+        axis_df = axis.getattr_from(self, pre="_", suf="_df")
+        assert isinstance(axis_df, DataFrame)
+        axis_query = axis.getattr_from(self._matrix_axis_query)
+
+        # If we can cache join IDs, prepare to add them to the cache.
+        joinids_cached = self._joinids._is_cached(axis)
+        query_columns = column_names
+        added_soma_joinid_to_columns = False
+        if (
+            not joinids_cached
+            and column_names is not None
+            and "soma_joinid" not in column_names
+        ):
+            # If we want to fill the join ID cache, ensure that we query the
+            # soma_joinid column so that it is included in the results.
+            # We'll filter it out later.
+            query_columns = ["soma_joinid"] + list(column_names)
+            added_soma_joinid_to_columns = True
+
+        # Do the actual query.
+        arrow_table = axis_df.read(
+            coords=axis_query.coords,
+            value_filter=axis_query.value_filter,
+            column_names=query_columns,
+        ).concat()
+
+        # Update the cache if needed. We can do this because no matter what
+        # other columns are queried for, the contents of the ``soma_joinid``
+        # column will be the same and can be safely stored.
+        if not joinids_cached:
+            setattr(
+                self._joinids,
+                axis.value,
+                arrow_table.column("soma_joinid").combine_chunks(),
+            )
+
+        # Drop soma_joinid column if we added it solely for use in filling
+        # the joinid cache.
+        if added_soma_joinid_to_columns:
+            arrow_table = arrow_table.drop(["soma_joinid"])
+        return arrow_table
+
+    def _axisp_inner(
+        self,
+        axis: Axis,
+        layer: str,
+    ) -> SparseRead:
+        p_name = f"{axis.value}p"
+        try:
+            ms = self._ms
+            axisp = ms.obsp if axis.value == "obs" else ms.varp
+        except (AttributeError, KeyError):
+            raise ValueError(f"Measurement does not contain {p_name} data")
+
+        try:
+            ap_layer = axisp[layer]
+        except KeyError:
+            raise ValueError(f"layer {layer!r} is not available in {p_name}")
+        if not isinstance(ap_layer, SparseNDArray):
+            raise TypeError(
+                f"Unexpected SOMA type {type(ap_layer).__name__}"
+                f" stored in {p_name} layer {layer!r}"
+            )
+
+        joinids = axis.getattr_from(self._joinids)
+        return ap_layer.read((joinids, joinids))
+
+    def _axism_inner(
+        self,
+        axis: Axis,
+        layer: str,
+    ) -> SparseRead:
+        m_name = f"{axis.value}m"
+
+        try:
+            ms = self._ms
+            axism = ms.obsm if axis.value == "obs" else ms.varm
+        except (AttributeError, KeyError):
+            raise ValueError(f"Measurement does not contain {m_name} data")
+
+        try:
+            axism_layer = axism[layer]
+        except KeyError:
+            raise ValueError(f"layer {layer!r} is not available in {m_name}")
+
+        if not isinstance(axism_layer, SparseNDArray):
+            raise TypeError(f"Unexpected SOMA type stored in '{m_name}' layer")
+
+        joinids = axis.getattr_from(self._joinids)
+        return axism_layer.read((joinids, slice(None)))
+
+    def _convert_to_ndarray(
+        self, axis: Axis, table: pa.Table, n_row: int, n_col: int
+    ) -> npt.NDArray[np.float32]:
+        indexer = cast(
+            Callable[[Numpyable], npt.NDArray[np.intp]],
+            axis.getattr_from(self.indexer, pre="by_"),
+        )
+        idx = indexer(table["soma_dim_0"])
+        z: npt.NDArray[np.float32] = np.zeros(n_row * n_col, dtype=np.float32)
+        np.put(z, idx * n_col + table["soma_dim_1"], table["soma_data"])
+        return z.reshape(n_row, n_col)
+
+    def _axisp_inner_ndarray(
+        self,
+        axis: Axis,
+        layer: str,
+    ) -> npt.NDArray[np.float32]:
+        n_row = n_col = len(axis.getattr_from(self._joinids))
+
+        table = self._axisp_inner(axis, layer).tables().concat()
+        return self._convert_to_ndarray(axis, table, n_row, n_col)
+
+    def _axism_inner_ndarray(
+        self,
+        axis: Axis,
+        layer: str,
+    ) -> npt.NDArray[np.float32]:
+        table = self._axism_inner(axis, layer).tables().concat()
+
+        n_row = len(axis.getattr_from(self._joinids))
+        n_col = len(table["soma_dim_1"].unique())
+
+        return self._convert_to_ndarray(axis, table, n_row, n_col)
+
+    @property
+    def _obs_df(self) -> DataFrame:
+        return self.experiment.obs
+
+    @property
+    def _ms(self) -> Measurement:
+        return self.experiment.ms[self.measurement_name]
+
+    @property
+    def _var_df(self) -> DataFrame:
+        return self._ms.var
+
+    @property
+    def _threadpool(self) -> ThreadPoolExecutor:
+        """
+        Returns the threadpool provided by the experiment's context.
+        If not available, creates a thread pool just in time."""
+        context = self.experiment.context
+        if context and context.threadpool:
+            return context.threadpool
+
+        if self._threadpool_ is None:
+            self._threadpool_ = ThreadPoolExecutor()
+        return self._threadpool_
+
+
+@attrs.define(frozen=True)
+class MatrixAxisQuery:
+    """The per-axis user query definition. Private."""
+
+    obs: AxisQuery
+    var: AxisQuery
+
+
+@attrs.define
+class JoinIDCache:
+    """A cache for per-axis join ids in the query. Private."""
+
+    owner: ExperimentAxisQuery
+
+    _cached_obs: Optional[pa.IntegerArray] = None
+    _cached_var: Optional[pa.IntegerArray] = None
+
+    def _is_cached(self, axis: Axis) -> bool:
+        field = "_cached_" + axis.value
+        return getattr(self, field) is not None
+
+    def preload(self, pool: ThreadPoolExecutor) -> None:
+        if self._cached_obs is not None and self._cached_var is not None:
+            return
+        obs_ft = pool.submit(lambda: self.obs)
+        var_ft = pool.submit(lambda: self.var)
+        # Wait for them and raise in case of error.
+        obs_ft.result()
+        var_ft.result()
+
+    @property
+    def obs(self) -> pa.IntegerArray:
+        """Join IDs for the obs axis. Will load and cache if not already."""
+        if not self._cached_obs:
+            self._cached_obs = load_joinids(
+                self.owner._obs_df, self.owner._matrix_axis_query.obs
+            )
+        return self._cached_obs
+
+    @obs.setter
+    def obs(self, val: pa.IntegerArray) -> None:
+        self._cached_obs = val
+
+    @property
+    def var(self) -> pa.IntegerArray:
+        """Join IDs for the var axis. Will load and cache if not already."""
+        if not self._cached_var:
+            self._cached_var = load_joinids(
+                self.owner._var_df, self.owner._matrix_axis_query.var
+            )
+        return self._cached_var
+
+    @var.setter
+    def var(self, val: pa.IntegerArray) -> None:
+        self._cached_var = val
+
+
+def load_joinids(df: DataFrame, axq: AxisQuery) -> pa.IntegerArray:
+    tbl = df.read(
+        axq.coords,
+        value_filter=axq.value_filter,
+        column_names=["soma_joinid"],
+    ).concat()
+    return tbl.column("soma_joinid").combine_chunks()
diff --git a/apis/python/tests/test_experiment_query.py b/apis/python/tests/test_experiment_query.py
index 1831e5e59c..401831f29d 100644
--- a/apis/python/tests/test_experiment_query.py
+++ b/apis/python/tests/test_experiment_query.py
@@ -4,6 +4,7 @@
 from typing import Tuple
 from unittest import mock
 
+import attrs
 import numpy as np
 import pandas as pd
 import pyarrow as pa
@@ -13,8 +14,10 @@
 from somacore import AxisQuery, options
 
 import tiledbsoma as soma
-from tiledbsoma import SOMATileDBContext, _factory
+from tiledbsoma import SOMATileDBContext, _factory, pytiledbsoma
 from tiledbsoma._collection import CollectionBase
+from tiledbsoma._experiment import Experiment
+from tiledbsoma._query import Axis, ExperimentAxisQuery
 from tiledbsoma.experiment_query import X_as_series
 
 from tests._util import raises_no_typeguard
@@ -58,7 +61,7 @@ def soma_experiment(
     varp_layer_names,
     obsm_layer_names,
     varm_layer_names,
-):
+) -> Experiment:
     with soma.Experiment.create((tmp_path / "exp").as_posix()) as exp:
         add_dataframe(exp, "obs", n_obs)
         ms = exp.add_new_collection("ms")
@@ -88,7 +91,7 @@ def soma_experiment(
             for varm_layer_name in varm_layer_names:
                 add_sparse_array(varm, varm_layer_name, (n_vars, N_FEATURES))
 
-    return _factory.open((tmp_path / "exp").as_posix())
+    return Experiment.open((tmp_path / "exp").as_posix())
 
 
 def get_soma_experiment_with_context(soma_experiment, context):
@@ -99,7 +102,7 @@ def get_soma_experiment_with_context(soma_experiment, context):
 @pytest.mark.parametrize("n_obs,n_vars,X_layer_names", [(101, 11, ("raw", "extra"))])
 def test_experiment_query_all(soma_experiment):
     """Test a query with default obs_query / var_query -- i.e., query all."""
-    with soma.ExperimentAxisQuery(soma_experiment, "RNA") as query:
+    with ExperimentAxisQuery(soma_experiment, "RNA") as query:
         assert query.n_obs == 101
         assert query.n_vars == 11
 
@@ -304,7 +307,7 @@ def test_experiment_query_batch_size(soma_experiment):
     This test merely verifies that the batch_size parameter is accepted
     but as a no-op.
     """
-    with soma.ExperimentAxisQuery(soma_experiment, "RNA") as query:
+    with ExperimentAxisQuery(soma_experiment, "RNA") as query:
         tbls = query.obs(batch_size=options.BatchSize(count=100))
         assert len(list(tbls)) == 1  # batch_size currently not implemented
 
@@ -315,7 +318,7 @@ def test_experiment_query_partitions(soma_experiment):
     partitions is currently not supported by this implementation of SOMA.
     This test checks if a ValueError is raised if a partitioning is requested.
     """
-    with soma.ExperimentAxisQuery(soma_experiment, "RNA") as query:
+    with ExperimentAxisQuery(soma_experiment, "RNA") as query:
         with pytest.raises(ValueError):
             query.obs(partitions=options.IOfN(i=0, n=3)).concat()
 
@@ -328,7 +331,7 @@ def test_experiment_query_partitions(soma_experiment):
 
 @pytest.mark.parametrize("n_obs,n_vars", [(10, 10)])
 def test_experiment_query_result_order(soma_experiment):
-    with soma.ExperimentAxisQuery(soma_experiment, "RNA") as query:
+    with ExperimentAxisQuery(soma_experiment, "RNA") as query:
         # Since obs is 1-dimensional, row-major and column-major should be the same
         obs_data_row_major = (
             query.obs(result_order="row-major").concat()["label"].to_numpy()
@@ -393,11 +396,10 @@ def test_experiment_axis_query_with_none(soma_experiment):
     """Test query by value filter"""
     obs_label_values = ["3", "7", "38", "99"]
 
-    with soma.ExperimentAxisQuery(
+    with ExperimentAxisQuery(
         experiment=soma_experiment,
         measurement_name="RNA",
         obs_query=soma.AxisQuery(value_filter=f"label in {obs_label_values}"),
-        var_query=None,
     ) as query:
         assert query.n_obs == len(obs_label_values)
         assert query.obs().concat()["label"].to_pylist() == obs_label_values
@@ -462,7 +464,7 @@ def test_X_layers(soma_experiment):
 @pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
 def test_experiment_query_indexer(soma_experiment):
     """Test result indexer"""
-    with soma.ExperimentAxisQuery(
+    with ExperimentAxisQuery(
         soma_experiment,
         "RNA",
         obs_query=soma.AxisQuery(coords=(slice(1, 10),)),
@@ -501,7 +503,7 @@ def test_experiment_query_indexer(soma_experiment):
 
 
 @pytest.mark.parametrize("n_obs,n_vars", [(2833, 107)])
-def test_error_corners(soma_experiment: soma.Experiment):
+def test_error_corners(soma_experiment: Experiment):
     """Verify a couple of error conditions / corner cases."""
     # Unknown Measurement name
     with pytest.raises(ValueError):
@@ -532,16 +534,16 @@ def test_error_corners(soma_experiment: soma.Experiment):
         with soma_experiment.axis_query("RNA") as query:
             with raises_no_typeguard(KeyError):
                 next(query.X(lyr_name))
-            with pytest.raises(ValueError):
+            with raises_no_typeguard(ValueError):
                 next(query.obsp(lyr_name))
-            with pytest.raises(ValueError):
+            with raises_no_typeguard(ValueError):
                 next(query.varp(lyr_name))
 
 
 @pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
-def test_query_cleanup(soma_experiment: soma.Experiment):
+def test_query_cleanup(soma_experiment: Experiment):
     """
-    Verify soma.Experiment.query works as context manager and stand-alone,
+    Verify Experiment.query works as context manager and stand-alone,
     and that it cleans up correctly.
     """
     from contextlib import closing
@@ -574,16 +576,16 @@ def test_query_cleanup(soma_experiment: soma.Experiment):
 def test_experiment_query_obsp_varp_obsm_varm(soma_experiment):
     obs_slice = slice(3, 72)
     var_slice = slice(7, 21)
-    with soma.ExperimentAxisQuery(
+    with ExperimentAxisQuery(
         soma_experiment,
         "RNA",
-        obs_query=soma.AxisQuery(coords=(obs_slice,)),
-        var_query=soma.AxisQuery(coords=(var_slice,)),
+        obs_query=AxisQuery(coords=(obs_slice,)),
+        var_query=AxisQuery(coords=(var_slice,)),
     ) as query:
         assert query.n_obs == obs_slice.stop - obs_slice.start + 1
         assert query.n_vars == var_slice.stop - var_slice.start + 1
 
-        with pytest.raises(ValueError):
+        with raises_no_typeguard(ValueError):
             next(query.obsp("no-such-layer"))
 
         with pytest.raises(ValueError):
@@ -682,44 +684,40 @@ def test_experiment_query_to_anndata_obsp_varp(soma_experiment):
 
 def test_axis_query():
     """Basic test of the AxisQuery class"""
-    assert soma.AxisQuery().coords == ()
-    assert soma.AxisQuery().value_filter is None
-    assert soma.AxisQuery() == soma.AxisQuery(coords=())
+    assert AxisQuery().coords == ()
+    assert AxisQuery().value_filter is None
+    assert AxisQuery() == AxisQuery(coords=())
 
-    assert soma.AxisQuery(coords=(1,)).coords == (1,)
-    assert soma.AxisQuery(coords=(slice(1, 2),)).coords == (slice(1, 2),)
-    assert soma.AxisQuery(coords=((1, 88),)).coords == ((1, 88),)
+    assert AxisQuery(coords=(1,)).coords == (1,)
+    assert AxisQuery(coords=(slice(1, 2),)).coords == (slice(1, 2),)
+    assert AxisQuery(coords=((1, 88),)).coords == ((1, 88),)
 
-    assert soma.AxisQuery(coords=(1, 2)).coords == (1, 2)
-    assert soma.AxisQuery(coords=(slice(1, 2), slice(None))).coords == (
+    assert AxisQuery(coords=(1, 2)).coords == (1, 2)
+    assert AxisQuery(coords=(slice(1, 2), slice(None))).coords == (
         slice(1, 2),
         slice(None),
     )
-    assert soma.AxisQuery(coords=(slice(1, 2),)).value_filter is None
+    assert AxisQuery(coords=(slice(1, 2),)).value_filter is None
 
-    assert soma.AxisQuery(value_filter="foo == 'bar'").value_filter == "foo == 'bar'"
-    assert soma.AxisQuery(value_filter="foo == 'bar'").coords == ()
+    assert AxisQuery(value_filter="foo == 'bar'").value_filter == "foo == 'bar'"
+    assert AxisQuery(value_filter="foo == 'bar'").coords == ()
 
-    assert soma.AxisQuery(
-        coords=(slice(1, 100),), value_filter="foo == 'bar'"
-    ).coords == (
+    assert AxisQuery(coords=(slice(1, 100),), value_filter="foo == 'bar'").coords == (
         slice(1, 100),
     )
     assert (
-        soma.AxisQuery(
-            coords=(slice(1, 100),), value_filter="foo == 'bar'"
-        ).value_filter
+        AxisQuery(coords=(slice(1, 100),), value_filter="foo == 'bar'").value_filter
         == "foo == 'bar'"
     )
 
     with pytest.raises(TypeError):
-        soma.AxisQuery(coords=True)
+        AxisQuery(coords=True)
 
     with pytest.raises(TypeError):
-        soma.AxisQuery(value_filter=[])
+        AxisQuery(value_filter=[])
 
     with pytest.raises(TypeError):
-        soma.AxisQuery(coords=({},))
+        AxisQuery(coords=({},))
 
 
 def test_X_as_series():
@@ -801,10 +799,10 @@ def test_experiment_query_column_names(soma_experiment):
     # column_names and value_filter
     with soma_experiment.axis_query(
         "RNA",
-        obs_query=soma.AxisQuery(
+        obs_query=AxisQuery(
             value_filter="label in [" + ",".join(f"'{i}'" for i in range(101)) + "]"
         ),
-        var_query=soma.AxisQuery(
+        var_query=AxisQuery(
             value_filter="label in [" + ",".join(f"'{i}'" for i in range(99)) + "]"
         ),
     ) as query:
@@ -860,7 +858,7 @@ def test_experiment_query_mp_disjoint_arrow_coords(soma_experiment):
     for ids in slices:
         with soma_experiment.axis_query(
             "RNA",
-            obs_query=soma.AxisQuery(coords=(ids,)),
+            obs_query=AxisQuery(coords=(ids,)),
         ) as query:
             assert query.obs_joinids() == ids
 
@@ -951,7 +949,7 @@ def test_empty_categorical_query(conftest_pbmc_small_exp):
         measurement_name="RNA", obs_query=AxisQuery(value_filter='groups == "foo"')
     )
     # Empty query on a categorical column raised ArrowInvalid before TileDB 2.21; see https://github.com/single-cell-data/TileDB-SOMA/pull/2299
-    m = re.fullmatch(r"libtiledb=(\d+\.\d+\.\d+)", soma.pytiledbsoma.version())
+    m = re.fullmatch(r"libtiledb=(\d+\.\d+\.\d+)", pytiledbsoma.version())
     version = m.group(1).split(".")
     major, minor = int(version[0]), int(version[1])
 
@@ -959,3 +957,24 @@ def test_empty_categorical_query(conftest_pbmc_small_exp):
     with ctx:
         obs = q.obs().concat()
         assert len(obs) == 0
+
+
+@attrs.define(frozen=True)
+class IHaveObsVarStuff:
+    obs: int
+    var: int
+    the_obs_suf: str
+    the_var_suf: str
+
+
+def test_axis_helpers() -> None:
+    thing = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
+    assert 1 == Axis.OBS.getattr_from(thing)
+    assert 2 == Axis.VAR.getattr_from(thing)
+    assert "observe" == Axis.OBS.getattr_from(thing, pre="the_", suf="_suf")
+    assert "vary" == Axis.VAR.getattr_from(thing, pre="the_", suf="_suf")
+    ovdict = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
+    assert "erve" == Axis.OBS.getitem_from(ovdict)
+    assert "y" == Axis.VAR.getitem_from(ovdict)
+    assert "hide" == Axis.OBS.getitem_from(ovdict, pre="i_", suf="cure")
+    assert "???" == Axis.VAR.getitem_from(ovdict, pre="i_", suf="cure")

From 9b5366f515d04b03201300fcb67b302eb8647553 Mon Sep 17 00:00:00 2001
From: Ryan Williams 
Date: Thu, 7 Nov 2024 14:14:03 -0500
Subject: [PATCH 2/2] ingest `_{{test_,}eager_iter,fast_csr}.py` from SOMA core

---
 .pre-commit-config.yaml                   |   2 +-
 apis/python/src/tiledbsoma/_eager_iter.py |  51 ++++
 apis/python/src/tiledbsoma/_fast_csr.py   | 302 ++++++++++++++++++++++
 apis/python/src/tiledbsoma/_query.py      |   4 +-
 apis/python/src/tiledbsoma/_read_iters.py |   2 +-
 apis/python/src/tiledbsoma/_types.py      |   2 +
 apis/python/tests/test_eager_iter.py      |  64 +++++
 7 files changed, 423 insertions(+), 4 deletions(-)
 create mode 100644 apis/python/src/tiledbsoma/_eager_iter.py
 create mode 100644 apis/python/src/tiledbsoma/_fast_csr.py
 create mode 100644 apis/python/tests/test_eager_iter.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2807cf368a..bc44cacf41 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
         # for more info.
         - "pandas-stubs>=2"
         # Temporary, for PR: see https://github.com/single-cell-data/SOMA/pull/244
-        - "git+https://github.com/single-cell-data/soma@9e81f07"
+        - "git+https://github.com/single-cell-data/soma@rw/abcs"
         - types-setuptools
       args: ["--config-file=apis/python/pyproject.toml", "apis/python/src", "apis/python/devtools"]
       pass_filenames: false
diff --git a/apis/python/src/tiledbsoma/_eager_iter.py b/apis/python/src/tiledbsoma/_eager_iter.py
new file mode 100644
index 0000000000..c42e52c1ab
--- /dev/null
+++ b/apis/python/src/tiledbsoma/_eager_iter.py
@@ -0,0 +1,51 @@
+from concurrent import futures
+from typing import Iterator, Optional, TypeVar
+
+_T = TypeVar("_T")
+
+
+class EagerIterator(Iterator[_T]):
+    def __init__(
+        self,
+        iterator: Iterator[_T],
+        pool: Optional[futures.Executor] = None,
+    ):
+        super().__init__()
+        self.iterator = iterator
+        self._pool = pool or futures.ThreadPoolExecutor()
+        self._own_pool = pool is None
+        self._preload_future = self._pool.submit(self.iterator.__next__)
+
+    def __next__(self) -> _T:
+        stopped = False
+        try:
+            if self._preload_future.cancel():
+                # If `.cancel` returns True, cancellation was successful.
+                # The self.iterator.__next__ call has not yet been started,
+                # and will never be started, so we can compute next ourselves.
+                # This prevents deadlocks if the thread pool is too small
+                # and we can never create a preload thread.
+                return next(self.iterator)
+            # `.cancel` returned false, so the preload is already running.
+            # Just wait for it.
+            return self._preload_future.result()
+        except StopIteration:
+            self._cleanup()
+            stopped = True
+            raise
+        finally:
+            if not stopped:
+                # If we have more to do, go for the next thing.
+                self._preload_future = self._pool.submit(self.iterator.__next__)
+
+    def _cleanup(self) -> None:
+        if self._own_pool:
+            self._pool.shutdown()
+
+    def __del__(self) -> None:
+        # Ensure the threadpool is cleaned up in the case where the
+        # iterator is not exhausted. For more information on __del__:
+        # https://docs.python.org/3/reference/datamodel.html#object.__del__
+        self._cleanup()
+        super_del = getattr(super(), "__del__", lambda: None)
+        super_del()
diff --git a/apis/python/src/tiledbsoma/_fast_csr.py b/apis/python/src/tiledbsoma/_fast_csr.py
new file mode 100644
index 0000000000..b4af4dc051
--- /dev/null
+++ b/apis/python/src/tiledbsoma/_fast_csr.py
@@ -0,0 +1,302 @@
+import os
+from concurrent.futures import Executor, ThreadPoolExecutor, wait
+from typing import Callable, List, NamedTuple, Tuple, Type, Union, cast
+
+import numba
+import numba.typed
+import numpy as np
+import numpy.typing as npt
+import pyarrow as pa
+import scipy.sparse as sp
+from somacore.query.types import IndexFactory, IndexLike
+
+from ._eager_iter import EagerIterator
+from ._funcs import _T, _Params
+from ._sparse_nd_array import SparseNDArray
+from ._types import NPIntArray, NPNDArray
+
+try:
+    # We need to `typeguard_ignore` the `@numba.jit`'d functions later in this file.
+    # However, for some reason `from ._funcs import typeguard_ignore` does not work here: tests raise errors like:
+    # ```
+    # E numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
+    # E Untyped global name 'TypeCheckMemo': Cannot determine Numba type of 
+    # E
+    # E File "src/tiledbsoma/_fast_csr.py", line 23:
+    # E
+    # E def read_csr(
+    # E ^
+    # ```
+    # Directly importing here is the only known solution.
+    from typeguard import typeguard_ignore
+except ImportError:
+    # Define a typeguard_ignore function so that we can use the `@typeguard_ignore`
+    # decorator without having to depend upon typeguard at runtime.
+    def typeguard_ignore(f: Callable[_Params, _T]) -> Callable[_Params, _T]:
+        """No-op. Returns the argument unchanged."""
+        return f
+
+
+def read_csr(
+    matrix: SparseNDArray,
+    obs_joinids: pa.IntegerArray,
+    var_joinids: pa.IntegerArray,
+    index_factory: IndexFactory,
+) -> "AccumulatedCSR":
+    if not isinstance(matrix, SparseNDArray) or matrix.ndim != 2:
+        raise TypeError("Can only read from a 2D SparseNDArray")
+
+    max_workers = (os.cpu_count() or 4) + 2
+    with ThreadPoolExecutor(max_workers=max_workers) as pool:
+        acc = _CSRAccumulator(
+            obs_joinids=obs_joinids,
+            var_joinids=var_joinids,
+            pool=pool,
+            index_factory=index_factory,
+        )
+        for tbl in EagerIterator(
+            matrix.read((obs_joinids, var_joinids)).tables(),
+            pool=pool,
+        ):
+            acc.append(tbl["soma_dim_0"], tbl["soma_dim_1"], tbl["soma_data"])
+
+        return acc.finalize()
+
+
+class AccumulatedCSR(NamedTuple):
+    """
+    Private.
+
+    Return type for the _CSRAccumulator.finalize method.
+    Contains a sparse CSR's constituent elements.
+    """
+
+    data: NPNDArray
+    indptr: NPIntArray
+    indices: NPIntArray
+    shape: Tuple[int, int]
+
+    def to_scipy(self) -> sp.csr_matrix:
+        """Create a Scipy ``sparse.csr_matrix`` from component elements.
+
+        Conceptually, this is identical to::
+
+            sparse.csr_matrix((data, indices, indptr), shape=shape)
+
+        This ugliness is to bypass the O(N) scan that
+        :meth:`sparse._cs_matrix.__init__`
+        does when a new compressed matrix is created.
+
+        See `SciPy bug 11496 `
+        for details.
+        """
+        matrix = sp.csr_matrix.__new__(sp.csr_matrix)
+        matrix.data = self.data
+        matrix.indptr = self.indptr
+        matrix.indices = self.indices
+        matrix._shape = self.shape
+        return matrix
+
+
+class _CSRAccumulator:
+    """
+    Fast accumulator of a CSR, based upon COO input.
+    """
+
+    def __init__(
+        self,
+        obs_joinids: pa.IntegerArray,
+        var_joinids: pa.IntegerArray,
+        pool: Executor,
+        index_factory: IndexFactory,
+    ):
+        self.obs_joinids = obs_joinids
+        self.var_joinids = var_joinids
+        self.pool = pool
+
+        self.shape: Tuple[int, int] = (len(self.obs_joinids), len(self.var_joinids))
+        self.obs_indexer = index_factory(self.obs_joinids)
+        self.var_indexer = index_factory(self.var_joinids)
+        self.row_length: NPIntArray = np.zeros(
+            (self.shape[0],), dtype=_select_dtype(self.shape[1])
+        )
+
+        # COO accumulated chunks, stored as list of triples (row_ind, col_ind, data)
+        self.coo_chunks: List[
+            Tuple[
+                NPIntArray,  # row_ind
+                NPIntArray,  # col_ind
+                NPNDArray,  # data
+            ]
+        ] = []
+
+    def append(
+        self,
+        row_joinids: Union[pa.Array, pa.ChunkedArray],
+        col_joinids: Union[pa.Array, pa.ChunkedArray],
+        data: Union[pa.Array, pa.ChunkedArray],
+    ) -> None:
+        """
+        At accumulation time, do several things:
+
+        * re-index to positional indices, and if possible, cast to smaller dtype
+          to minimize memory footprint (at cost of some amount of time)
+        * accumulate column counts by row, i.e., build the basis of the indptr
+        * cache the tuple of data, row, col
+        """
+        rows_future = self.pool.submit(
+            _reindex_and_cast,
+            self.obs_indexer,
+            row_joinids.to_numpy(),
+            _select_dtype(self.shape[0]),
+        )
+        cols_future = self.pool.submit(
+            _reindex_and_cast,
+            self.var_indexer,
+            col_joinids.to_numpy(),
+            _select_dtype(self.shape[1]),
+        )
+        row_ind = rows_future.result()
+        col_ind = cols_future.result()
+        self.coo_chunks.append((row_ind, col_ind, data.to_numpy()))  # type: ignore[arg-type]
+        _accum_row_length(self.row_length, row_ind)
+
+    def finalize(self) -> AccumulatedCSR:
+        nnz = sum(len(chunk[2]) for chunk in self.coo_chunks)
+        index_dtype = _select_dtype(nnz)
+        if nnz == 0:
+            # There is no way to infer matrix dtype, so use a default and return
+            # an empty matrix. Float32 is used as a default type, as it is most
+            # compatible with AnnData expectations.
+            empty = sp.csr_matrix((0, 0), dtype=np.float32)
+            return AccumulatedCSR(
+                data=empty.data,
+                indptr=empty.indptr,
+                indices=empty.indices,
+                shape=self.shape,
+            )
+
+        # cumsum row lengths to get indptr
+        indptr: NPIntArray = np.empty((self.shape[0] + 1,), dtype=index_dtype)
+        indptr[0:1] = 0
+        np.cumsum(self.row_length, out=indptr[1:])
+
+        # Parallel copy of data and column indices
+        indices: NPIntArray = np.empty((nnz,), dtype=index_dtype)
+        data: NPNDArray = np.empty((nnz,), dtype=self.coo_chunks[0][2].dtype)
+
+        # Empirically determined value. Needs to be large enough for reasonable
+        # concurrency, without excessive write cache conflict. Controls the
+        # number of rows that are processed in a single thread, and therefore
+        # is the primary tuning parameter related to concurrency.
+        row_rng_mask_bits = 18
+
+        n_jobs = (self.shape[0] >> row_rng_mask_bits) + 1
+        chunk_list = numba.typed.List(self.coo_chunks)
+        wait(
+            [
+                self.pool.submit(
+                    _copy_chunklist_range,
+                    chunk_list,
+                    data,
+                    indices,
+                    indptr,
+                    row_rng_mask_bits,
+                    job,
+                )
+                for job in range(n_jobs)
+            ]
+        )
+        _finalize_indptr(indptr)
+        return AccumulatedCSR(
+            data=data, indptr=indptr, indices=indices, shape=self.shape
+        )
+
+
+@typeguard_ignore  # type: ignore[misc]
+@numba.jit(nopython=True, nogil=True)  # type: ignore[attr-defined,misc]
+def _accum_row_length(
+    row_length: npt.NDArray[np.int64], row_ind: npt.NDArray[np.int64]
+) -> None:
+    for rind in row_ind:
+        row_length[rind] += 1
+
+
+@typeguard_ignore  # type: ignore[misc]
+@numba.jit(nopython=True, nogil=True)  # type: ignore[attr-defined,misc]
+def _copy_chunk_range(
+    row_ind_chunk: npt.NDArray[np.signedinteger[npt.NBitBase]],
+    col_ind_chunk: npt.NDArray[np.signedinteger[npt.NBitBase]],
+    data_chunk: NPNDArray,
+    data: NPNDArray,
+    indices: npt.NDArray[np.signedinteger[npt.NBitBase]],
+    indptr: npt.NDArray[np.signedinteger[npt.NBitBase]],
+    row_rng_mask: int,
+    row_rng_val: int,
+) -> None:
+    for n in range(len(data_chunk)):
+        row = row_ind_chunk[n]
+        if (row & row_rng_mask) != row_rng_val:
+            continue
+        ptr = indptr[row]
+        indices[ptr] = col_ind_chunk[n]
+        data[ptr] = data_chunk[n]
+        indptr[row] += 1
+
+
+@typeguard_ignore  # type: ignore[misc]
+@numba.jit(nopython=True, nogil=True)  # type: ignore[attr-defined,misc]
+def _copy_chunklist_range(
+    chunk_list: numba.typed.List,
+    data: NPNDArray,
+    indices: npt.NDArray[np.signedinteger[npt.NBitBase]],
+    indptr: npt.NDArray[np.signedinteger[npt.NBitBase]],
+    row_rng_mask_bits: int,
+    job: int,
+) -> None:
+    assert row_rng_mask_bits >= 1 and row_rng_mask_bits < 64
+    row_rng_mask = (2**64 - 1) >> row_rng_mask_bits << row_rng_mask_bits
+    row_rng_val = job << row_rng_mask_bits
+    for row_ind_chunk, col_ind_chunk, data_chunk in chunk_list:
+        _copy_chunk_range(
+            row_ind_chunk,
+            col_ind_chunk,
+            data_chunk,
+            data,
+            indices,
+            indptr,
+            row_rng_mask,
+            row_rng_val,
+        )
+
+
+@typeguard_ignore  # type: ignore[misc]
+@numba.jit(nopython=True, nogil=True)  # type: ignore[attr-defined,misc]
+def _finalize_indptr(indptr: npt.NDArray[np.signedinteger[npt.NBitBase]]) -> None:
+    prev = 0
+    for r in range(len(indptr)):
+        t = indptr[r]
+        indptr[r] = prev
+        prev = t
+
+
+def _select_dtype(
+    maxval: int,
+) -> Union[Type[np.int32], Type[np.int64]]:
+    """
+    Ascertain the "best" dtype for a zero-based index. Given our
+    goal of minimizing memory use, "best" is currently defined as
+    smallest.
+    """
+    if maxval > np.iinfo(np.int32).max:
+        return np.int64
+    else:
+        return np.int32
+
+
+def _reindex_and_cast(
+    index: IndexLike, ids: npt.NDArray[np.int64], target_dtype: npt.DTypeLike
+) -> npt.NDArray[np.int64]:
+    return cast(
+        npt.NDArray[np.int64], index.get_indexer(ids).astype(target_dtype, copy=False)
+    )
diff --git a/apis/python/src/tiledbsoma/_query.py b/apis/python/src/tiledbsoma/_query.py
index 888fcaf1ab..915d582bf3 100644
--- a/apis/python/src/tiledbsoma/_query.py
+++ b/apis/python/src/tiledbsoma/_query.py
@@ -46,7 +46,6 @@
     ResultOrder,
     ResultOrderStr,
 )
-from somacore.query import _fast_csr
 from somacore.query.query import (
     AxisColumnNames,
     Numpyable,
@@ -56,6 +55,7 @@
 
 if TYPE_CHECKING:
     from ._experiment import Experiment
+from ._fast_csr import read_csr
 from ._measurement import Measurement
 from ._sparse_nd_array import SparseNDArray
 
@@ -579,7 +579,7 @@ def _read_axis_mappings(
 
         x_matrices = {
             _xname: (
-                _fast_csr.read_csr(
+                read_csr(
                     layer,
                     obs_joinids,
                     var_joinids,
diff --git a/apis/python/src/tiledbsoma/_read_iters.py b/apis/python/src/tiledbsoma/_read_iters.py
index 06faca45c0..660bda6469 100644
--- a/apis/python/src/tiledbsoma/_read_iters.py
+++ b/apis/python/src/tiledbsoma/_read_iters.py
@@ -29,12 +29,12 @@
 import somacore
 from scipy import sparse
 from somacore import options
-from somacore.query._eager_iter import EagerIterator
 
 # This package's pybind11 code
 import tiledbsoma.pytiledbsoma as clib
 
 from . import _util
+from ._eager_iter import EagerIterator
 from ._exception import SOMAError
 from ._indexer import IntIndexer
 from ._types import NTuple
diff --git a/apis/python/src/tiledbsoma/_types.py b/apis/python/src/tiledbsoma/_types.py
index 4033c0d6fd..c298c40508 100644
--- a/apis/python/src/tiledbsoma/_types.py
+++ b/apis/python/src/tiledbsoma/_types.py
@@ -27,6 +27,7 @@
     NPInteger = np.integer[npt.NBitBase]
     NPFloating = np.floating[npt.NBitBase]
     NPNDArray = npt.NDArray[np.number[npt.NBitBase]]
+    NPIntArray = npt.NDArray[np.integer[npt.NBitBase]]
 else:
     # When not-type-checking, but running with `pandas>=2`, the "missing" type-params don't affect anything.
     PDSeries = pd.Series
@@ -38,6 +39,7 @@
     NPInteger = np.integer
     NPFloating = np.floating
     NPNDArray = np.ndarray
+    NPIntArray = np.ndarray
 
 
 Path = Union[str, pathlib.Path]
diff --git a/apis/python/tests/test_eager_iter.py b/apis/python/tests/test_eager_iter.py
new file mode 100644
index 0000000000..d5ebb925c3
--- /dev/null
+++ b/apis/python/tests/test_eager_iter.py
@@ -0,0 +1,64 @@
+import threading
+import unittest
+from concurrent import futures
+from unittest import mock
+
+from tiledbsoma._eager_iter import EagerIterator
+
+
+class EagerIterTest(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.kiddie_pool = futures.ThreadPoolExecutor(1)
+        """Tiny thread pool for testing."""
+        self.verify_pool = futures.ThreadPoolExecutor(1)
+        """Separate thread pool so verification is not blocked."""
+
+    def tearDown(self):
+        self.verify_pool.shutdown(wait=False)
+        self.kiddie_pool.shutdown(wait=False)
+        super().tearDown()
+
+    def test_thread_starvation(self):
+        sem = threading.Semaphore()
+        try:
+            # Monopolize the threadpool.
+            sem.acquire()
+            self.kiddie_pool.submit(sem.acquire)
+            eager = EagerIterator(iter("abc"), pool=self.kiddie_pool)
+            got_a = self.verify_pool.submit(lambda: next(eager))
+            self.assertEqual("a", got_a.result(0.1))
+            got_b = self.verify_pool.submit(lambda: next(eager))
+            self.assertEqual("b", got_b.result(0.1))
+            got_c = self.verify_pool.submit(lambda: next(eager))
+            self.assertEqual("c", got_c.result(0.1))
+            with self.assertRaises(StopIteration):
+                self.verify_pool.submit(lambda: next(eager)).result(0.1)
+        finally:
+            sem.release()
+
+    def test_nesting(self):
+        inner = EagerIterator(iter("abc"), pool=self.kiddie_pool)
+        outer = EagerIterator(inner, pool=self.kiddie_pool)
+        self.assertEqual(
+            "a, b, c", self.verify_pool.submit(", ".join, outer).result(0.1)
+        )
+
+    def test_exceptions(self):
+        flaky = mock.MagicMock()
+        flaky.__next__.side_effect = [1, 2, ValueError(), 3, 4]
+
+        eager_flaky = EagerIterator(flaky, pool=self.kiddie_pool)
+        got_1 = self.verify_pool.submit(lambda: next(eager_flaky))
+        self.assertEqual(1, got_1.result(0.1))
+        got_2 = self.verify_pool.submit(lambda: next(eager_flaky))
+        self.assertEqual(2, got_2.result(0.1))
+        with self.assertRaises(ValueError):
+            self.verify_pool.submit(lambda: next(eager_flaky)).result(0.1)
+        got_3 = self.verify_pool.submit(lambda: next(eager_flaky))
+        self.assertEqual(3, got_3.result(0.1))
+        got_4 = self.verify_pool.submit(lambda: next(eager_flaky))
+        self.assertEqual(4, got_4.result(0.1))
+        for _ in range(5):
+            with self.assertRaises(StopIteration):
+                self.verify_pool.submit(lambda: next(eager_flaky)).result(0.1)