From 4fb30c8fc4374dd0295ac7c90a64ecf78f88744e Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 14 Oct 2024 12:04:24 +0100 Subject: [PATCH] Fix dtypes for qr (#594) --- cubed/array_api/linalg.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cubed/array_api/linalg.py b/cubed/array_api/linalg.py index 8ca57877..0cfcb35f 100644 --- a/cubed/array_api/linalg.py +++ b/cubed/array_api/linalg.py @@ -4,6 +4,7 @@ # These functions are in both the main and linalg namespaces from cubed.array_api.data_type_functions import result_type +from cubed.array_api.dtypes import _floating_dtypes from cubed.array_api.linear_algebra_functions import ( # noqa: F401 matmul, matrix_transpose, @@ -33,6 +34,9 @@ def qr(x, /, *, mode="reduced") -> QRResult: if mode != "reduced": raise ValueError("qr only supports mode='reduced'") + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in qr") + if x.numblocks[1] > 1: raise ValueError( "qr only supports tall-and-skinny (single column chunk) arrays. " @@ -80,7 +84,7 @@ def _qr_first_step(A): nxp.linalg.qr, A, shapes=[A.shape, R1_shape], - dtypes=[nxp.float64, nxp.float64], + dtypes=[A.dtype, A.dtype], chunkss=[A.chunks, R1_chunks], extra_projected_mem=extra_projected_mem, ) @@ -119,7 +123,7 @@ def _qr_second_step(R1): nxp.linalg.qr, R1_single, shapes=[Q2_shape, R2_shape], - dtypes=[nxp.float64, nxp.float64], + dtypes=[R1.dtype, R1.dtype], chunkss=[Q2_chunks, R2_chunks], extra_projected_mem=extra_projected_mem, ) @@ -148,7 +152,7 @@ def _qr_third_step(Q1, Q2): Q1, Q2, shape=Q1_shape, - dtype=nxp.float64, + dtype=result_type(Q1, Q2), chunks=Q1_chunks, extra_projected_mem=extra_projected_mem, q1_chunks=Q1_chunks,