Skip to content

Commit

Permalink
Merge branch 'master' into update-tests-part-2
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana authored Dec 11, 2024
2 parents db0094c + c4997cc commit a9e76ef
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 2 deletions.
11 changes: 11 additions & 0 deletions dpnp/backend/extensions/blas/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
const std::int64_t,
char *,
const std::int64_t,
#if !defined(USE_ONEMKL_CUBLAS)
const bool,
#endif // !USE_ONEMKL_CUBLAS
const std::vector<sycl::event> &);

static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
Expand All @@ -74,7 +76,9 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
const std::int64_t ldb,
char *resultC,
const std::int64_t ldc,
#if !defined(USE_ONEMKL_CUBLAS)
const bool is_row_major,
#endif // !USE_ONEMKL_CUBLAS
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<Tab>(exec_q);
Expand Down Expand Up @@ -236,6 +240,7 @@ std::tuple<sycl::event, sycl::event, bool>
std::int64_t lda;
std::int64_t ldb;

// cuBLAS supports only column-major storage
#if defined(USE_ONEMKL_CUBLAS)
const bool is_row_major = false;

Expand Down Expand Up @@ -315,9 +320,15 @@ std::tuple<sycl::event, sycl::event, bool>
const char *b_typeless_ptr = matrixB.get_data();
char *r_typeless_ptr = resultC.get_data();

#if defined(USE_ONEMKL_CUBLAS)
sycl::event gemm_ev =
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);
#else
sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k,
a_typeless_ptr, lda, b_typeless_ptr, ldb,
r_typeless_ptr, ldc, is_row_major, depends);
#endif // USE_ONEMKL_CUBLAS

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {matrixA, matrixB, resultC}, {gemm_ev});
Expand Down
12 changes: 12 additions & 0 deletions dpnp/backend/extensions/blas/gemm_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
const char *,
const char *,
char *,
#if !defined(USE_ONEMKL_CUBLAS)
const bool,
#endif // !USE_ONEMKL_CUBLAS
const std::vector<sycl::event> &);

static gemm_batch_impl_fn_ptr_t
Expand All @@ -83,7 +85,9 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
const char *matrixA,
const char *matrixB,
char *resultC,
#if !defined(USE_ONEMKL_CUBLAS)
const bool is_row_major,
#endif // !USE_ONEMKL_CUBLAS
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<Tab>(exec_q);
Expand Down Expand Up @@ -311,6 +315,7 @@ std::tuple<sycl::event, sycl::event, bool>
std::int64_t lda;
std::int64_t ldb;

// cuBLAS supports only column-major storage
#if defined(USE_ONEMKL_CUBLAS)
const bool is_row_major = false;

Expand Down Expand Up @@ -391,10 +396,17 @@ std::tuple<sycl::event, sycl::event, bool>
const char *b_typeless_ptr = matrixB.get_data();
char *r_typeless_ptr = resultC.get_data();

#if defined(USE_ONEMKL_CUBLAS)
sycl::event gemm_batch_ev =
gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
strideb, stridec, transA, transB, a_typeless_ptr,
b_typeless_ptr, r_typeless_ptr, depends);
#else
sycl::event gemm_batch_ev =
gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
strideb, stridec, transA, transB, a_typeless_ptr,
b_typeless_ptr, r_typeless_ptr, is_row_major, depends);
#endif // USE_ONEMKL_CUBLAS

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});
Expand Down
11 changes: 11 additions & 0 deletions dpnp/backend/extensions/blas/gemv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ typedef sycl::event (*gemv_impl_fn_ptr_t)(sycl::queue &,
const std::int64_t,
char *,
const std::int64_t,
#if !defined(USE_ONEMKL_CUBLAS)
const bool,
#endif // !USE_ONEMKL_CUBLAS
const std::vector<sycl::event> &);

static gemv_impl_fn_ptr_t gemv_dispatch_vector[dpctl_td_ns::num_types];
Expand All @@ -69,7 +71,9 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
const std::int64_t incx,
char *vectorY,
const std::int64_t incy,
#if !defined(USE_ONEMKL_CUBLAS)
const bool is_row_major,
#endif // !USE_ONEMKL_CUBLAS
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);
Expand Down Expand Up @@ -190,6 +194,7 @@ std::pair<sycl::event, sycl::event>
oneapi::mkl::transpose transA;
std::size_t src_nelems;

// cuBLAS supports only column-major storage
#if defined(USE_ONEMKL_CUBLAS)
const bool is_row_major = false;
std::int64_t m;
Expand Down Expand Up @@ -299,9 +304,15 @@ std::pair<sycl::event, sycl::event>
y_typeless_ptr -= (y_shape[0] - 1) * std::abs(incy) * y_elemsize;
}

#if defined(USE_ONEMKL_CUBLAS)
sycl::event gemv_ev =
gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx,
y_typeless_ptr, incy, depends);
#else
sycl::event gemv_ev =
gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx,
y_typeless_ptr, incy, is_row_major, depends);
#endif // USE_ONEMKL_CUBLAS

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {matrixA, vectorX, vectorY}, {gemv_ev});
Expand Down
36 changes: 34 additions & 2 deletions dpnp/tests/third_party/cupy/sorting_tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,22 @@ def test_nanargmin_zero_size_axis1(self, xp, dtype):
a = testing.shaped_random((0, 1), xp, dtype)
return xp.nanargmin(a, axis=1)

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_allclose()
def test_nanargmin_out_float_dtype(self, xp, dtype):
    """Run ``nanargmin`` over axis 1 of a 1x1 float array, writing into ``out``.

    The index result is stored in a preallocated one-element ``int64``
    buffer, and that buffer is what the test returns.
    """
    # NOTE: ``dtype`` is supplied by the decorator but is not read here;
    # the input literal and the ``int64`` output buffer are fixed.
    index_buf = xp.empty(1, dtype="int64")
    xp.nanargmin(xp.array([[0.0]]), axis=1, out=index_buf)
    return index_buf

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_array_equal()
def test_nanargmin_out_int_dtype(self, xp, dtype):
    """Run ``nanargmin`` on the integer array ``[1, 0]``, writing into ``out``.

    No axis is given, and the result is stored in a preallocated
    zero-dimensional ``int64`` buffer, which the test returns.
    """
    # NOTE: ``dtype`` is supplied by the decorator but is not read here.
    index_buf = xp.empty((), dtype="int64")
    xp.nanargmin(xp.array([1, 0]), out=index_buf)
    return index_buf


class TestNanArgMax:

Expand Down Expand Up @@ -623,6 +639,22 @@ def test_nanargmax_zero_size_axis1(self, xp, dtype):
a = testing.shaped_random((0, 1), xp, dtype)
return xp.nanargmax(a, axis=1)

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_allclose()
def test_nanargmax_out_float_dtype(self, xp, dtype):
    """Run ``nanargmax`` over axis 1 of a 1x1 float array, writing into ``out``.

    The index result is stored in a preallocated one-element ``int64``
    buffer, and that buffer is what the test returns.
    """
    # NOTE: ``dtype`` is supplied by the decorator but is not read here;
    # the input literal and the ``int64`` output buffer are fixed.
    index_buf = xp.empty(1, dtype="int64")
    xp.nanargmax(xp.array([[0.0]]), axis=1, out=index_buf)
    return index_buf

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_array_equal()
def test_nanargmax_out_int_dtype(self, xp, dtype):
    """Run ``nanargmax`` on the integer array ``[0, 1]``, writing into ``out``.

    No axis is given, and the result is stored in a preallocated
    zero-dimensional ``int64`` buffer, which the test returns.
    """
    # NOTE: ``dtype`` is supplied by the decorator but is not read here.
    index_buf = xp.empty((), dtype="int64")
    xp.nanargmax(xp.array([0, 1]), out=index_buf)
    return index_buf


@testing.parameterize(
*testing.product(
Expand Down Expand Up @@ -771,7 +803,7 @@ def test_invalid_sorter(self):

def test_nonint_sorter(self):
for xp in (numpy, cupy):
x = testing.shaped_arange((12,), xp, xp.float32)
x = testing.shaped_arange((12,), xp, xp.float64)
bins = xp.array([10, 4, 2, 1, 8])
sorter = xp.array([], dtype=xp.float32)
with pytest.raises((TypeError, ValueError)):
Expand Down Expand Up @@ -865,7 +897,7 @@ def test_invalid_sorter(self):

def test_nonint_sorter(self):
for xp in (numpy, cupy):
x = testing.shaped_arange((12,), xp, xp.float32)
x = testing.shaped_arange((12,), xp, xp.float64)
bins = xp.array([10, 4, 2, 1, 8])
sorter = xp.array([], dtype=xp.float32)
with pytest.raises((TypeError, ValueError)):
Expand Down

0 comments on commit a9e76ef

Please sign in to comment.