Skip to content

Commit

Permalink
Fix dtypes for qr (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Oct 14, 2024
1 parent d361e00 commit 4fb30c8
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions cubed/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# These functions are in both the main and linalg namespaces
from cubed.array_api.data_type_functions import result_type
from cubed.array_api.dtypes import _floating_dtypes
from cubed.array_api.linear_algebra_functions import ( # noqa: F401
matmul,
matrix_transpose,
Expand Down Expand Up @@ -33,6 +34,9 @@ def qr(x, /, *, mode="reduced") -> QRResult:
if mode != "reduced":
raise ValueError("qr only supports mode='reduced'")

if x.dtype not in _floating_dtypes:
raise TypeError("Only floating-point dtypes are allowed in qr")

if x.numblocks[1] > 1:
raise ValueError(
"qr only supports tall-and-skinny (single column chunk) arrays. "
Expand Down Expand Up @@ -80,7 +84,7 @@ def _qr_first_step(A):
nxp.linalg.qr,
A,
shapes=[A.shape, R1_shape],
dtypes=[nxp.float64, nxp.float64],
dtypes=[A.dtype, A.dtype],
chunkss=[A.chunks, R1_chunks],
extra_projected_mem=extra_projected_mem,
)
Expand Down Expand Up @@ -119,7 +123,7 @@ def _qr_second_step(R1):
nxp.linalg.qr,
R1_single,
shapes=[Q2_shape, R2_shape],
dtypes=[nxp.float64, nxp.float64],
dtypes=[R1.dtype, R1.dtype],
chunkss=[Q2_chunks, R2_chunks],
extra_projected_mem=extra_projected_mem,
)
Expand Down Expand Up @@ -148,7 +152,7 @@ def _qr_third_step(Q1, Q2):
Q1,
Q2,
shape=Q1_shape,
dtype=nxp.float64,
dtype=result_type(Q1, Q2),
chunks=Q1_chunks,
extra_projected_mem=extra_projected_mem,
q1_chunks=Q1_chunks,
Expand Down

0 comments on commit 4fb30c8

Please sign in to comment.