diff --git a/docs/release-notes/3392.feat.md b/docs/release-notes/3392.feat.md new file mode 100644 index 0000000000..db20ce7016 --- /dev/null +++ b/docs/release-notes/3392.feat.md @@ -0,0 +1 @@ +Add a `use_64_bit_elem_mul` option to {class}`scanpy.settings` for more accurate elem-wise multiplication {smaller}`Ilan Gold` diff --git a/src/scanpy/_settings.py b/src/scanpy/_settings.py index 54b51b6420..7d5eb8f430 100644 --- a/src/scanpy/_settings.py +++ b/src/scanpy/_settings.py @@ -118,6 +118,7 @@ def __init__( _vector_friendly: bool = False, _low_resolution_warning: bool = True, n_pcs=50, + use_64_bit_elem_mul=False, ): # logging self._root_logger = _RootLogger(logging.INFO) # level will be replaced @@ -156,6 +157,7 @@ def __init__( """Stores the previous memory usage.""" self.N_PCS = n_pcs + self.use_64_bit_elem_mul = use_64_bit_elem_mul @property def verbosity(self) -> Verbosity: @@ -412,6 +414,18 @@ def categories_to_ignore(self, categories_to_ignore: Iterable[str]): _type_check(cat, f"categories_to_ignore[{i}]", str) self._categories_to_ignore = categories_to_ignore + @property + def use_64_bit_elem_mul(self) -> bool: + """\ + Use a 64bit float buffer as the output target for element-wise multiplication. + """ + return self._use_64_bit_elem_mul + + @use_64_bit_elem_mul.setter + def use_64_bit_elem_mul(self, use_64_bit_elem_mul: bool): + _type_check(use_64_bit_elem_mul, "use_64_bit_elem_mul", bool) + self._use_64_bit_elem_mul = use_64_bit_elem_mul + # -------------------------------------------------------------------------------- # Functions # -------------------------------------------------------------------------------- diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index 67e2ae03c8..ccd05be36b 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -570,6 +570,12 @@ def get_literal_vals(typ: UnionType | Any) -> KeysView[Any]: _SupportedArray = _MemoryArray | DaskArray +def use_64_bit_float(arrays: Iterable[np.ndarray | DaskArray]) -> bool: + return settings.use_64_bit_elem_mul and any( + np.issubdtype(a.dtype, np.floating) for a in arrays + ) + + @singledispatch def elem_mul(x: _SupportedArray, y: _SupportedArray) -> _SupportedArray: raise NotImplementedError @@ -581,14 +587,16 @@ def _elem_mul_in_mem(x: _MemoryArray, y: _MemoryArray) -> _MemoryArray: if isinstance(x, sparse.spmatrix): # returns coo_matrix, so cast back to input type return type(x)(x.multiply(y)) - return x * y + return np.multiply(x, y, dtype=np.float64 if use_64_bit_float((x, y)) else None) @elem_mul.register(DaskArray) def _elem_mul_dask(x: DaskArray, y: DaskArray) -> DaskArray: import dask.array as da - return da.map_blocks(elem_mul, x, y) + return da.map_blocks( + elem_mul, x, y, dtype=np.float64 if use_64_bit_float((x, y)) else None + ) if TYPE_CHECKING: diff --git a/tests/test_utils.py b/tests/test_utils.py index f8a38a5f9d..55628dd492 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,6 +9,7 @@ from packaging.version import Version from scipy.sparse import csr_matrix, issparse +from scanpy import settings from scanpy._compat import DaskArray, pkg_version from scanpy._utils import ( axis_mul_or_truediv, @@ -147,6 +148,26 @@ def test_elem_mul(array_type): np.testing.assert_array_equal(res, expd) +@pytest.mark.parametrize( + ("elem_dtype", "expected_dtype", "use_64_bit_elem_mul"), + [ + pytest.param(np.float32, np.float64, True, id="use_64_bit_elem_mul"), + pytest.param(np.float32, np.float32, False, id="use_default_elem_mul_dtype"), + pytest.param(np.int8, np.int8, True, id="ignore_64_bit_elem_for_int_elem"), + ], +) +def test_elem_mul_64_bit( + elem_dtype: np.dtype, expected_dtype: np.dtype, *, use_64_bit_elem_mul: bool +): + settings.use_64_bit_elem_mul = use_64_bit_elem_mul + m1 = np.array([[0, 1, 1], [1, 0, 1]], dtype=elem_dtype) + m2 = np.array([[2, 2, 1], [3, 2, 0]], dtype=elem_dtype) + expd = np.array([[0, 2, 1], [3, 0, 0]], dtype=expected_dtype) + res = elem_mul(m1, m2) + assert res.dtype == expected_dtype + np.testing.assert_array_equal(res, expd) + + @pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_axis_sum(array_type): m1 = array_type(asarray([[0, 1, 1], [1, 0, 1]]))