From dd39db31b8760093a35e8700dccfc493568789bd Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 25 Oct 2024 10:51:45 +0200 Subject: [PATCH 1/7] add device_type in sycl --- accessor/sycl_helper.hpp | 192 +++++++++++++++++++++++++++++++++++++++ dpcpp/base/types.hpp | 125 +++++++++++++++++++++++++ 2 files changed, 317 insertions(+) create mode 100644 accessor/sycl_helper.hpp create mode 100644 dpcpp/base/types.hpp diff --git a/accessor/sycl_helper.hpp b/accessor/sycl_helper.hpp new file mode 100644 index 00000000000..793587c30d3 --- /dev/null +++ b/accessor/sycl_helper.hpp @@ -0,0 +1,192 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_ACCESSOR_SYCL_HELPER_HPP_ +#define GKO_ACCESSOR_SYCL_HELPER_HPP_ + + +#include +#include + +#include "block_col_major.hpp" +#include "reduced_row_major.hpp" +#include "row_major.hpp" +#include "scaled_reduced_row_major.hpp" +#include "utils.hpp" + + +namespace sycl { +inline namespace _V1 { + + +class half; + + +} +} // namespace sycl + + +namespace gko { + + +class half; + + +namespace acc { +namespace detail { + + +template +struct sycl_type { + using type = T; +}; + +template <> +struct sycl_type { + using type = sycl::half; +}; + +// Unpack cv and reference / pointer qualifiers +template +struct sycl_type { + using type = const typename sycl_type::type; +}; + +template +struct sycl_type { + using type = volatile typename sycl_type::type; +}; + +template +struct sycl_type { + using type = typename sycl_type::type*; +}; + +template +struct sycl_type { + using type = typename sycl_type::type&; +}; + +template +struct sycl_type { + using type = typename sycl_type::type&&; +}; + + +// Transform the underlying type of std::complex +template +struct sycl_type> { + using type = std::complex::type>; +}; + + +} // namespace detail + + +/** + * This is an alias for SYCL's equivalent of `T`. + * + * @tparam T a type + */ +template +using sycl_type_t = typename detail::sycl_type::type; + + +/** + * Reinterprets the passed in value as a SYCL type. + * + * @param val the value to reinterpret + * + * @return `val` reinterpreted to SYCL type + */ +template +std::enable_if_t::value || std::is_reference::value, + sycl_type_t> +as_sycl_type(T val) +{ + return reinterpret_cast>(val); +} + + +/** + * @copydoc as_sycl_type() + */ +template +std::enable_if_t::value && !std::is_reference::value, + sycl_type_t> +as_sycl_type(T val) +{ + return *reinterpret_cast*>(&val); +} + + +/** + * Changes the types and reinterprets the passed in range pointers as a SYCL + * types. + * + * @param r the range which pointers need to be reinterpreted + * + * @return `r` with appropriate types and reinterpreted to SYCL pointers + */ +template +GKO_ACC_INLINE auto as_sycl_range( + const range>& r) +{ + return range< + reduced_row_major, sycl_type_t>>( + r.get_accessor().get_size(), + as_sycl_type(r.get_accessor().get_stored_data()), + r.get_accessor().get_stride()); +} + +/** + * @copydoc as_sycl_range() + */ +template +GKO_ACC_INLINE auto as_sycl_range( + const range>& r) +{ + return range, + sycl_type_t, mask>>( + r.get_accessor().get_size(), + as_sycl_type(r.get_accessor().get_stored_data()), + r.get_accessor().get_storage_stride(), + as_sycl_type(r.get_accessor().get_scalar()), + r.get_accessor().get_scalar_stride()); +} + +/** + * @copydoc as_sycl_range() + */ +template +GKO_ACC_INLINE auto as_sycl_range(const range>& r) +{ + return range, dim>>( + r.get_accessor().lengths, as_sycl_type(r.get_accessor().data), + r.get_accessor().stride); +} + +/** + * @copydoc as_sycl_range() + */ +template +GKO_ACC_INLINE auto as_sycl_range(const range>& r) +{ + return range, dim>>( + r.get_accessor().lengths, as_sycl_type(r.get_accessor().data), + r.get_accessor().stride); +} + +template +GKO_ACC_INLINE auto as_device_range(AccType&& acc) +{ + return as_device_range(std::forward(acc)); +} + + +} // namespace acc +} // namespace gko + + +#endif // GKO_ACCESSOR_SYCL_HELPER_HPP_ diff --git a/dpcpp/base/types.hpp b/dpcpp/base/types.hpp new file mode 100644 index 00000000000..64c446c356e --- /dev/null +++ b/dpcpp/base/types.hpp @@ -0,0 +1,125 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_DPCPP_BASE_TYPES_HPP_ +#define GKO_DPCPP_BASE_TYPES_HPP_ + + +#include + +#include + +#include +#include + + +namespace gko { +namespace kernels { +namespace dpcpp { +namespace detail { + + +template +struct sycl_type_impl { + using type = T; +}; + +template +struct sycl_type_impl { + using type = typename sycl_type_impl::type*; +}; + +template +struct sycl_type_impl { + using type = typename sycl_type_impl::type&; +}; + +template +struct sycl_type_impl { + using type = const typename sycl_type_impl::type; +}; + +template +struct sycl_type_impl { + using type = volatile typename sycl_type_impl::type; +}; + +template <> +struct sycl_type_impl { + using type = sycl::half; +}; + +template +struct sycl_type_impl> { + using type = std::complex::type>; +}; + +} // namespace detail + + +/** + * This is an alias for SYCL's equivalent of `T`. + * + * @tparam T a type + */ +template +using sycl_type = typename detail::sycl_type_impl::type; + +/** + * This is an alias for SYCL/HIP's equivalent of `T` depending on the namespace. + * + * @tparam T a type + */ +template +using device_type = sycl_type; + + +/** + * Reinterprets the passed in value as a SYCL type. + * + * @param val the value to reinterpret + * + * @return `val` reinterpreted to SYCL type + */ +template +inline std::enable_if_t< + std::is_pointer::value || std::is_reference::value, sycl_type> +as_sycl_type(T val) +{ + return reinterpret_cast>(val); +} + + +/** + * @copydoc as_sycl_type() + */ +template +inline std::enable_if_t< + !std::is_pointer::value && !std::is_reference::value, sycl_type> +as_sycl_type(T val) +{ + return *reinterpret_cast*>(&val); +} + + +/** + * Reinterprets the passed in value as a SYCL/HIP type depending on the + * namespace. + * + * @param val the value to reinterpret + * + * @return `val` reinterpreted to SYCL/HIP type + */ +template +inline device_type as_device_type(T val) +{ + return as_sycl_type(val); +} + + +} // namespace dpcpp +} // namespace kernels +} // namespace gko + +#endif // GKO_DPCPP_BASE_TYPES_HPP_ From 6e0018faec8bd8448e7db54d1397103a2943063b Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 25 Oct 2024 10:51:58 +0200 Subject: [PATCH 2/7] add device_type in kernel_launch --- common/unified/base/kernel_launch.hpp | 12 ++---------- dpcpp/base/kernel_launch_reduction.dp.hpp | 13 ++++++++----- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/common/unified/base/kernel_launch.hpp b/common/unified/base/kernel_launch.hpp index 455d3d67a6d..a1a25a8ca4f 100644 --- a/common/unified/base/kernel_launch.hpp +++ b/common/unified/base/kernel_launch.hpp @@ -74,16 +74,7 @@ GKO_INLINE GKO_ATTRIBUTES constexpr unpack_member_type unpack_member(T value) namespace gko { namespace kernels { namespace dpcpp { - - -template -using device_type = T; - -template -device_type as_device_type(T value) -{ - return value; -} +#include "dpcpp/base/types.hpp" template @@ -95,6 +86,7 @@ GKO_INLINE GKO_ATTRIBUTES constexpr unpack_member_type unpack_member(T value) return value; } + } // namespace dpcpp } // namespace kernels } // namespace gko diff --git a/dpcpp/base/kernel_launch_reduction.dp.hpp b/dpcpp/base/kernel_launch_reduction.dp.hpp index 83436966ecb..f45a92269a5 100644 --- a/dpcpp/base/kernel_launch_reduction.dp.hpp +++ b/dpcpp/base/kernel_launch_reduction.dp.hpp @@ -239,7 +239,8 @@ void run_kernel_reduction_cached(std::shared_ptr exec, [&](std::uint32_t cfg) { return cfg == desired_cfg; }, syn::value_list(), syn::value_list(), syn::value_list(), syn::type_list<>(), exec, fn, op, - finalize, identity, result, size, tmp, map_to_device(args)...); + finalize, as_device_type(identity), as_device_type(result), size, tmp, + map_to_device(args)...); } @@ -261,7 +262,8 @@ void run_kernel_reduction_cached(std::shared_ptr exec, [&](std::uint32_t cfg) { return cfg == desired_cfg; }, syn::value_list(), syn::value_list(), syn::value_list(), syn::type_list<>(), exec, fn, op, - finalize, identity, result, size, tmp, map_to_device(args)...); + finalize, as_device_type(identity), as_device_type(result), size, tmp, + map_to_device(args)...); } @@ -658,8 +660,8 @@ void run_kernel_row_reduction_cached(std::shared_ptr exec, [&](std::uint32_t cfg) { return cfg == desired_cfg; }, syn::value_list(), syn::value_list(), syn::value_list(), syn::type_list<>(), exec, fn, op, - finalize, identity, result, result_stride, size, tmp, - map_to_device(args)...); + finalize, as_device_type(identity), as_device_type(result), + result_stride, size, tmp, map_to_device(args)...); } @@ -681,7 +683,8 @@ void run_kernel_col_reduction_cached(std::shared_ptr exec, [&](std::uint32_t cfg) { return cfg == desired_cfg; }, syn::value_list(), syn::value_list(), syn::value_list(), syn::type_list<>(), exec, fn, op, - finalize, identity, result, size, tmp, map_to_device(args)...); + finalize, as_device_type(identity), as_device_type(result), size, tmp, + map_to_device(args)...); } From ae20a256baaa688dcc6a7b7bfb972e7dde4eca02 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 25 Oct 2024 13:30:26 +0200 Subject: [PATCH 3/7] reduction sycl type --- dpcpp/components/reduction.dp.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dpcpp/components/reduction.dp.hpp b/dpcpp/components/reduction.dp.hpp index aed8166d601..933f6db7817 100644 --- a/dpcpp/components/reduction.dp.hpp +++ b/dpcpp/components/reduction.dp.hpp @@ -21,6 +21,7 @@ #include "dpcpp/base/dim3.dp.hpp" #include "dpcpp/base/dpct.hpp" #include "dpcpp/base/helper.hpp" +#include "dpcpp/base/types.hpp" #include "dpcpp/components/cooperative_groups.dp.hpp" #include "dpcpp/components/thread_ids.dp.hpp" #include "dpcpp/components/uninitialized_array.hpp" @@ -189,8 +190,9 @@ void reduce_add_array(dim3 grid, dim3 block, size_type dynamic_shared_memory, const ValueType* source, ValueType* result) { queue->submit([&](sycl::handler& cgh) { - sycl::local_accessor< - uninitialized_array, 0> + sycl::local_accessor, + DeviceConfig::block_size>, + 0> block_sum_acc_ct1(cgh); cgh.parallel_for( @@ -198,8 +200,8 @@ void reduce_add_array(dim3 grid, dim3 block, size_type dynamic_shared_memory, [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(DeviceConfig::subgroup_size)]] { reduce_add_array( - size, source, result, item_ct1, - *block_sum_acc_ct1.get_pointer()); + size, as_device_type(source), as_device_type(result), + item_ct1, *block_sum_acc_ct1.get_pointer()); }); }); } From 6ec9fe9ab47502f01adfe16bd5d7571cbec37e0b Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 25 Oct 2024 13:30:35 +0200 Subject: [PATCH 4/7] component sycy type --- dpcpp/base/device_matrix_data_kernels.dp.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/dpcpp/base/device_matrix_data_kernels.dp.cpp b/dpcpp/base/device_matrix_data_kernels.dp.cpp index f39615613fe..107d6fbea32 100644 --- a/dpcpp/base/device_matrix_data_kernels.dp.cpp +++ b/dpcpp/base/device_matrix_data_kernels.dp.cpp @@ -9,6 +9,7 @@ #include #include "dpcpp/base/onedpl.hpp" +#include "dpcpp/base/types.hpp" namespace gko { @@ -22,12 +23,13 @@ void remove_zeros(std::shared_ptr exec, array& values, array& row_idxs, array& col_idxs) { - using nonzero_type = matrix_data_entry; + using device_value_type = device_type; auto size = values.get_size(); auto policy = onedpl_policy(exec); - auto nnz = std::count_if( - policy, values.get_const_data(), values.get_const_data() + size, - [](ValueType val) { return is_nonzero(val); }); + auto nnz = + std::count_if(policy, as_device_type(values.get_const_data()), + as_device_type(values.get_const_data()) + size, + [](device_value_type val) { return is_nonzero(val); }); if (nnz < size) { // allocate new storage array new_values{exec, static_cast(nnz)}; @@ -36,10 +38,10 @@ void remove_zeros(std::shared_ptr exec, // copy nonzeros auto input_it = oneapi::dpl::make_zip_iterator( row_idxs.get_const_data(), col_idxs.get_const_data(), - values.get_const_data()); - auto output_it = oneapi::dpl::make_zip_iterator(new_row_idxs.get_data(), - new_col_idxs.get_data(), - new_values.get_data()); + as_device_type(values.get_const_data())); + auto output_it = oneapi::dpl::make_zip_iterator( + new_row_idxs.get_data(), new_col_idxs.get_data(), + as_device_type(new_values.get_data())); std::copy_if(policy, input_it, input_it + size, output_it, [](auto tuple) { return is_nonzero(std::get<2>(tuple)); }); // swap out storage From 1f870baad46ed39e248fd7427ceb74c699ebe931 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 25 Oct 2024 13:30:56 +0200 Subject: [PATCH 5/7] matrix sycl type --- dpcpp/matrix/coo_kernels.dp.cpp | 35 ++-- dpcpp/matrix/csr_kernels.dp.cpp | 208 +++++++++++++---------- dpcpp/matrix/dense_kernels.dp.cpp | 45 ++--- dpcpp/matrix/diagonal_kernels.dp.cpp | 5 +- dpcpp/matrix/ell_kernels.dp.cpp | 15 +- dpcpp/matrix/sparsity_csr_kernels.dp.cpp | 13 +- 6 files changed, 183 insertions(+), 138 deletions(-) diff --git a/dpcpp/matrix/coo_kernels.dp.cpp b/dpcpp/matrix/coo_kernels.dp.cpp index 595af92b33b..71407f37f02 100644 --- a/dpcpp/matrix/coo_kernels.dp.cpp +++ b/dpcpp/matrix/coo_kernels.dp.cpp @@ -293,20 +293,22 @@ void spmv2(std::shared_ptr exec, const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols); int num_lines = ceildiv(nnz, nwarps * config::warp_size); abstract_spmv(coo_grid, coo_block, 0, exec->get_queue(), nnz, - num_lines, a->get_const_values(), + num_lines, as_device_type(a->get_const_values()), a->get_const_col_idxs(), a->get_const_row_idxs(), - b->get_const_values(), b->get_stride(), - c->get_values(), c->get_stride()); + as_device_type(b->get_const_values()), + b->get_stride(), as_device_type(c->get_values()), + c->get_stride()); } else { int num_elems = ceildiv(nnz, nwarps * config::warp_size) * config::warp_size; const dim3 coo_grid(ceildiv(nwarps, warps_in_block), ceildiv(b_ncols, config::warp_size)); abstract_spmm(coo_grid, coo_block, 0, exec->get_queue(), nnz, - num_elems, a->get_const_values(), + num_elems, as_device_type(a->get_const_values()), a->get_const_col_idxs(), a->get_const_row_idxs(), - b_ncols, b->get_const_values(), b->get_stride(), - c->get_values(), c->get_stride()); + b_ncols, as_device_type(b->get_const_values()), + b->get_stride(), as_device_type(c->get_values()), + c->get_stride()); } } } @@ -331,21 +333,24 @@ void advanced_spmv2(std::shared_ptr exec, int num_lines = ceildiv(nnz, nwarps * config::warp_size); const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols); abstract_spmv(coo_grid, coo_block, 0, exec->get_queue(), nnz, - num_lines, alpha->get_const_values(), - a->get_const_values(), a->get_const_col_idxs(), - a->get_const_row_idxs(), b->get_const_values(), - b->get_stride(), c->get_values(), c->get_stride()); + num_lines, as_device_type(alpha->get_const_values()), + as_device_type(a->get_const_values()), + a->get_const_col_idxs(), a->get_const_row_idxs(), + as_device_type(b->get_const_values()), + b->get_stride(), as_device_type(c->get_values()), + c->get_stride()); } else { int num_elems = ceildiv(nnz, nwarps * config::warp_size) * config::warp_size; const dim3 coo_grid(ceildiv(nwarps, warps_in_block), ceildiv(b_ncols, config::warp_size)); abstract_spmm(coo_grid, coo_block, 0, exec->get_queue(), nnz, - num_elems, alpha->get_const_values(), - a->get_const_values(), a->get_const_col_idxs(), - a->get_const_row_idxs(), b_ncols, - b->get_const_values(), b->get_stride(), - c->get_values(), c->get_stride()); + num_elems, as_device_type(alpha->get_const_values()), + as_device_type(a->get_const_values()), + a->get_const_col_idxs(), a->get_const_row_idxs(), + b_ncols, as_device_type(b->get_const_values()), + b->get_stride(), as_device_type(c->get_values()), + c->get_stride()); } } } diff --git a/dpcpp/matrix/csr_kernels.dp.cpp b/dpcpp/matrix/csr_kernels.dp.cpp index 7e5d0229c86..0c32fc55442 100644 --- a/dpcpp/matrix/csr_kernels.dp.cpp +++ b/dpcpp/matrix/csr_kernels.dp.cpp @@ -18,6 +18,7 @@ #include #include +#include "accessor/sycl_helper.hpp" #include "core/base/array_access.hpp" #include "core/base/mixed_precision_types.hpp" #include "core/base/utils.hpp" @@ -31,6 +32,7 @@ #include "dpcpp/base/dim3.dp.hpp" #include "dpcpp/base/dpct.hpp" #include "dpcpp/base/helper.hpp" +#include "dpcpp/base/types.hpp" #include "dpcpp/components/atomic.dp.hpp" #include "dpcpp/components/cooperative_groups.dp.hpp" #include "dpcpp/components/reduction.dp.hpp" @@ -1241,29 +1243,35 @@ void merge_path_spmv(syn::value_list, if (grid_num > 0) { csr::kernel::abstract_merge_path_spmv( grid, block, 0, exec->get_queue(), - static_cast(a->get_size()[0]), a_vals, - a->get_const_col_idxs(), a->get_const_row_ptrs(), - a->get_const_srow(), b_vals, c_vals, row_out.get_data(), - val_out.get_data()); + static_cast(a->get_size()[0]), + acc::as_device_range(a_vals), a->get_const_col_idxs(), + a->get_const_row_ptrs(), a->get_const_srow(), + acc::as_device_range(b_vals), acc::as_device_range(c_vals), + row_out.get_data(), as_device_type(val_out.get_data())); } csr::kernel::abstract_reduce( 1, spmv_block_size, 0, exec->get_queue(), grid_num, - val_out.get_data(), row_out.get_data(), c_vals); + as_device_type(val_out.get_data()), row_out.get_data(), + acc::as_device_range(c_vals)); } else if (alpha != nullptr && beta != nullptr) { if (grid_num > 0) { csr::kernel::abstract_merge_path_spmv( grid, block, 0, exec->get_queue(), static_cast(a->get_size()[0]), - alpha->get_const_values(), a_vals, a->get_const_col_idxs(), - a->get_const_row_ptrs(), a->get_const_srow(), b_vals, - beta->get_const_values(), c_vals, row_out.get_data(), - val_out.get_data()); + as_device_type(alpha->get_const_values()), + acc::as_device_range(a_vals), a->get_const_col_idxs(), + a->get_const_row_ptrs(), a->get_const_srow(), + acc::as_device_range(b_vals), + as_device_type(beta->get_const_values()), + acc::as_device_range(c_vals), row_out.get_data(), + as_device_type(val_out.get_data())); } - csr::kernel::abstract_reduce(1, spmv_block_size, 0, - exec->get_queue(), grid_num, - val_out.get_data(), row_out.get_data(), - alpha->get_const_values(), c_vals); + csr::kernel::abstract_reduce( + 1, spmv_block_size, 0, exec->get_queue(), grid_num, + as_device_type(val_out.get_data()), row_out.get_data(), + as_device_type(alpha->get_const_values()), + acc::as_device_range(c_vals)); } else { GKO_KERNEL_NOT_FOUND; } @@ -1317,17 +1325,20 @@ void classical_spmv(syn::value_list, if (alpha == nullptr && beta == nullptr) { if (grid.x > 0 && grid.y > 0) { kernel::abstract_classical_spmv( - grid, block, 0, exec->get_queue(), a->get_size()[0], a_vals, - a->get_const_col_idxs(), a->get_const_row_ptrs(), b_vals, - c_vals); + grid, block, 0, exec->get_queue(), a->get_size()[0], + acc::as_device_range(a_vals), a->get_const_col_idxs(), + a->get_const_row_ptrs(), acc::as_device_range(b_vals), + acc::as_device_range(c_vals)); } } else if (alpha != nullptr && beta != nullptr) { if (grid.x > 0 && grid.y > 0) { kernel::abstract_classical_spmv( grid, block, 0, exec->get_queue(), a->get_size()[0], - alpha->get_const_values(), a_vals, a->get_const_col_idxs(), - a->get_const_row_ptrs(), b_vals, beta->get_const_values(), - c_vals); + as_device_type(alpha->get_const_values()), + acc::as_device_range(a_vals), a->get_const_col_idxs(), + a->get_const_row_ptrs(), acc::as_device_range(b_vals), + as_device_type(beta->get_const_values()), + acc::as_device_range(c_vals)); } } else { GKO_KERNEL_NOT_FOUND; @@ -1368,17 +1379,19 @@ void load_balance_spmv(std::shared_ptr exec, csr::kernel::abstract_spmv( csr_grid, csr_block, 0, exec->get_queue(), nwarps, static_cast(a->get_size()[0]), - alpha->get_const_values(), a_vals, a->get_const_col_idxs(), - a->get_const_row_ptrs(), a->get_const_srow(), b_vals, - c_vals); + as_device_type(alpha->get_const_values()), + acc::as_device_range(a_vals), a->get_const_col_idxs(), + a->get_const_row_ptrs(), a->get_const_srow(), + acc::as_device_range(b_vals), acc::as_device_range(c_vals)); } } else { if (csr_grid.x > 0 && csr_grid.y > 0) { csr::kernel::abstract_spmv( csr_grid, csr_block, 0, exec->get_queue(), nwarps, - static_cast(a->get_size()[0]), a_vals, - a->get_const_col_idxs(), a->get_const_row_ptrs(), - a->get_const_srow(), b_vals, c_vals); + static_cast(a->get_size()[0]), + acc::as_device_range(a_vals), a->get_const_col_idxs(), + a->get_const_row_ptrs(), a->get_const_srow(), + acc::as_device_range(b_vals), acc::as_device_range(c_vals)); } } } @@ -1711,9 +1724,10 @@ void compute_submatrix(std::shared_ptr exec, kernel::compute_submatrix_idxs_and_vals( grid_dim, block_dim, 0, exec->get_queue(), num_rows, num_cols, num_nnz, row_offset, col_offset, source->get_const_row_ptrs(), - source->get_const_col_idxs(), source->get_const_values(), + source->get_const_col_idxs(), + as_device_type(source->get_const_values()), result->get_const_row_ptrs(), result->get_col_idxs(), - result->get_values()); + as_device_type(result->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -1928,19 +1942,20 @@ void spgemm(std::shared_ptr exec, auto num_rows = a->get_size()[0]; const auto a_row_ptrs = a->get_const_row_ptrs(); const auto a_cols = a->get_const_col_idxs(); - const auto a_vals = a->get_const_values(); + const auto a_vals = as_device_type(a->get_const_values()); const auto b_row_ptrs = b->get_const_row_ptrs(); const auto b_cols = b->get_const_col_idxs(); - const auto b_vals = b->get_const_values(); + const auto b_vals = as_device_type(b->get_const_values()); auto c_row_ptrs = c->get_row_ptrs(); auto queue = exec->get_queue(); - array> heap_array( + using device_value_type = device_type; + array> heap_array( exec, a->get_num_stored_elements()); auto heap = heap_array.get_data(); auto col_heap = - reinterpret_cast*>(heap); + reinterpret_cast*>(heap); // first sweep: count nnz for each row queue->submit([&](sycl::handler& cgh) { @@ -1949,7 +1964,7 @@ void spgemm(std::shared_ptr exec, c_row_ptrs[a_row] = spgemm_multiway_merge( a_row, a_row_ptrs, a_cols, a_vals, b_row_ptrs, b_cols, b_vals, col_heap, [](size_type) { return IndexType{}; }, - [](ValueType, IndexType, IndexType&) {}, + [](device_value_type, IndexType, IndexType&) {}, [](IndexType, IndexType& nnz) { nnz++; }); }); }); @@ -1965,7 +1980,7 @@ void spgemm(std::shared_ptr exec, c_col_idxs_array.resize_and_reset(new_nnz); c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); - auto c_vals = c_vals_array.get_data(); + auto c_vals = as_device_type(c_vals_array.get_data()); queue->submit([&](sycl::handler& cgh) { cgh.parallel_for(sycl::range<1>{num_rows}, [=](sycl::id<1> idx) { @@ -1974,16 +1989,18 @@ void spgemm(std::shared_ptr exec, a_row, a_row_ptrs, a_cols, a_vals, b_row_ptrs, b_cols, b_vals, heap, [&](size_type row) { - return std::make_pair(zero(), c_row_ptrs[row]); + return std::make_pair(zero(), + c_row_ptrs[row]); }, - [](ValueType val, IndexType, - std::pair& state) { + [](device_value_type val, IndexType, + std::pair& state) { state.first += val; }, - [&](IndexType col, std::pair& state) { + [&](IndexType col, + std::pair& state) { c_col_idxs[state.second] = col; c_vals[state.second] = state.first; - state.first = zero(); + state.first = zero(); state.second++; }); }); @@ -2005,27 +2022,27 @@ void advanced_spgemm(std::shared_ptr exec, auto num_rows = a->get_size()[0]; const auto a_row_ptrs = a->get_const_row_ptrs(); const auto a_cols = a->get_const_col_idxs(); - const auto a_vals = a->get_const_values(); + const auto a_vals = as_device_type(a->get_const_values()); const auto b_row_ptrs = b->get_const_row_ptrs(); const auto b_cols = b->get_const_col_idxs(); - const auto b_vals = b->get_const_values(); + const auto b_vals = as_device_type(b->get_const_values()); const auto d_row_ptrs = d->get_const_row_ptrs(); const auto d_cols = d->get_const_col_idxs(); - const auto d_vals = d->get_const_values(); + const auto d_vals = as_device_type(d->get_const_values()); auto c_row_ptrs = c->get_row_ptrs(); - const auto alpha_vals = alpha->get_const_values(); - const auto beta_vals = beta->get_const_values(); + const auto alpha_vals = as_device_type(alpha->get_const_values()); + const auto beta_vals = as_device_type(beta->get_const_values()); constexpr auto sentinel = std::numeric_limits::max(); auto queue = exec->get_queue(); // first sweep: count nnz for each row - - array> heap_array( + using device_value_type = device_type; + array> heap_array( exec, a->get_num_stored_elements()); auto heap = heap_array.get_data(); auto col_heap = - reinterpret_cast*>(heap); + reinterpret_cast*>(heap); // first sweep: count nnz for each row queue->submit([&](sycl::handler& cgh) { @@ -2037,7 +2054,7 @@ void advanced_spgemm(std::shared_ptr exec, c_row_ptrs[a_row] = spgemm_multiway_merge( a_row, a_row_ptrs, a_cols, a_vals, b_row_ptrs, b_cols, b_vals, col_heap, [](size_type row) { return IndexType{}; }, - [](ValueType, IndexType, IndexType&) {}, + [](device_value_type, IndexType, IndexType&) {}, [&](IndexType col, IndexType& nnz) { // skip smaller elements from d while (d_col <= col) { @@ -2064,7 +2081,7 @@ void advanced_spgemm(std::shared_ptr exec, c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); - auto c_vals = c_vals_array.get_data(); + auto c_vals = as_device_type(c_vals_array.get_data()); queue->submit([&](sycl::handler& cgh) { cgh.parallel_for(sycl::range<1>{num_rows}, [=](sycl::id<1> idx) { @@ -2072,24 +2089,26 @@ void advanced_spgemm(std::shared_ptr exec, auto d_nz = d_row_ptrs[a_row]; const auto d_end = d_row_ptrs[a_row + 1]; auto d_col = checked_load(d_cols, d_nz, d_end, sentinel); - auto d_val = checked_load(d_vals, d_nz, d_end, zero()); - const auto valpha = alpha_vals[0]; - const auto vbeta = beta_vals[0]; + auto d_val = + checked_load(d_vals, d_nz, d_end, zero()); + const auto valpha = as_device_type(alpha_vals[0]); + const auto vbeta = as_device_type(beta_vals[0]); auto c_nz = spgemm_multiway_merge( a_row, a_row_ptrs, a_cols, a_vals, b_row_ptrs, b_cols, b_vals, heap, [&](size_type row) { - return std::make_pair(zero(), + return std::make_pair(zero(), c_row_ptrs[row]); }, - [](ValueType val, IndexType, - std::pair& state) { + [](device_value_type val, IndexType, + std::pair& state) { state.first += val; }, - [&](IndexType col, std::pair& state) { + [&](IndexType col, + std::pair& state) { // handle smaller elements from d - ValueType part_d_val{}; + device_value_type part_d_val{}; while (d_col <= col) { if (d_col == col) { part_d_val = d_val; @@ -2101,12 +2120,12 @@ void advanced_spgemm(std::shared_ptr exec, d_nz++; d_col = checked_load(d_cols, d_nz, d_end, sentinel); d_val = checked_load(d_vals, d_nz, d_end, - zero()); + zero()); } c_col_idxs[state.second] = col; c_vals[state.second] = vbeta * part_d_val + valpha * state.first; - state.first = zero(); + state.first = zero(); state.second++; }) .second; @@ -2117,7 +2136,8 @@ void advanced_spgemm(std::shared_ptr exec, c_nz++; d_nz++; d_col = checked_load(d_cols, d_nz, d_end, sentinel); - d_val = checked_load(d_vals, d_nz, d_end, zero()); + d_val = checked_load(d_vals, d_nz, d_end, + zero()); } }); }); @@ -2176,11 +2196,12 @@ void spgeam(std::shared_ptr exec, auto c_cols = c_col_idxs_array.get_data(); auto c_vals = c_vals_array.get_data(); - const auto a_vals = a->get_const_values(); - const auto b_vals = b->get_const_values(); - const auto alpha_vals = alpha->get_const_values(); - const auto beta_vals = beta->get_const_values(); + const auto a_vals = as_device_type(a->get_const_values()); + const auto b_vals = as_device_type(b->get_const_values()); + const auto alpha_vals = as_device_type(alpha->get_const_values()); + const auto beta_vals = as_device_type(beta->get_const_values()); + using device_value_type = device_type; // count number of non-zeros per row queue->submit([&](sycl::handler& cgh) { cgh.parallel_for(sycl::range<1>{num_rows}, [=](sycl::id<1> idx) { @@ -2197,8 +2218,10 @@ void spgeam(std::shared_ptr exec, const auto b_col = checked_load(b_cols, b_idx, b_end, sentinel); const bool use_a = a_col <= b_col; const bool use_b = b_col <= a_col; - const auto a_val = use_a ? a_vals[a_idx] : zero(); - const auto b_val = use_b ? b_vals[b_idx] : zero(); + const auto a_val = + use_a ? a_vals[a_idx] : zero(); + const auto b_val = + use_b ? b_vals[b_idx] : zero(); c_cols[c_nz] = std::min(a_col, b_col); c_vals[c_nz] = alpha * a_val + beta * b_val; c_nz++; @@ -2222,12 +2245,12 @@ void fill_in_dense(std::shared_ptr exec, const auto stride = result->get_stride(); const auto row_ptrs = source->get_const_row_ptrs(); const auto col_idxs = source->get_const_col_idxs(); - const auto vals = source->get_const_values(); + const auto vals = as_device_type(source->get_const_values()); auto grid_dim = ceildiv(num_rows, default_block_size); kernel::fill_in_dense(grid_dim, default_block_size, 0, exec->get_queue(), num_rows, row_ptrs, col_idxs, vals, stride, - result->get_values()); + as_device_type(result->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -2254,13 +2277,13 @@ void generic_transpose(std::shared_ptr exec, auto queue = exec->get_queue(); const auto row_ptrs = orig->get_const_row_ptrs(); const auto cols = orig->get_const_col_idxs(); - const auto vals = orig->get_const_values(); + const auto vals = as_device_type(orig->get_const_values()); array counts{exec, num_cols + 1}; auto tmp_counts = counts.get_data(); auto out_row_ptrs = trans->get_row_ptrs(); auto out_cols = trans->get_col_idxs(); - auto out_vals = trans->get_values(); + auto out_vals = as_device_type(trans->get_values()); components::fill_array(exec, tmp_counts, num_cols, IndexType{}); queue->submit([&](sycl::handler& cgh) { @@ -2336,8 +2359,8 @@ void inv_symm_permute(std::shared_ptr exec, inv_symm_permute_kernel( copy_num_blocks, default_block_size, 0, exec->get_queue(), num_rows, perm, orig->get_const_row_ptrs(), orig->get_const_col_idxs(), - orig->get_const_values(), permuted->get_row_ptrs(), - permuted->get_col_idxs(), permuted->get_values()); + as_deivice_type(orig->get_const_values()), permuted->get_row_ptrs(), + permuted->get_col_idxs(), as_device_type(permuted->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -2362,9 +2385,9 @@ void inv_nonsymm_permute(std::shared_ptr exec, inv_nonsymm_permute_kernel( copy_num_blocks, default_block_size, 0, exec->get_queue(), num_rows, row_perm, col_perm, orig->get_const_row_ptrs(), - orig->get_const_col_idxs(), orig->get_const_values(), + orig->get_const_col_idxs(), as_deivice_type(orig->get_const_values()), permuted->get_row_ptrs(), permuted->get_col_idxs(), - permuted->get_values()); + as_deivice_type(permuted->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -2389,8 +2412,9 @@ void row_permute(std::shared_ptr exec, row_permute_kernel( copy_num_blocks, default_block_size, 0, exec->get_queue(), num_rows, perm, orig->get_const_row_ptrs(), orig->get_const_col_idxs(), - orig->get_const_values(), row_permuted->get_row_ptrs(), - row_permuted->get_col_idxs(), row_permuted->get_values()); + as_deivice_type(orig->get_const_values()), row_permuted->get_row_ptrs(), + row_permuted->get_col_idxs(), + as_deivice_type(row_permuted->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -2415,8 +2439,9 @@ void inv_row_permute(std::shared_ptr exec, inv_row_permute_kernel( copy_num_blocks, default_block_size, 0, exec->get_queue(), num_rows, perm, orig->get_const_row_ptrs(), orig->get_const_col_idxs(), - orig->get_const_values(), row_permuted->get_row_ptrs(), - row_permuted->get_col_idxs(), row_permuted->get_values()); + as_deivice_type(orig->get_const_values()), row_permuted->get_row_ptrs(), + row_permuted->get_col_idxs(), + as_deivice_type(row_permuted->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -2441,8 +2466,8 @@ void inv_symm_scale_permute(std::shared_ptr exec, inv_symm_scale_permute_kernel( copy_num_blocks, default_block_size, 0, exec->get_queue(), num_rows, scale, perm, orig->get_const_row_ptrs(), orig->get_const_col_idxs(), - orig->get_const_values(), permuted->get_row_ptrs(), - permuted->get_col_idxs(), permuted->get_values()); + as_deivice_type(orig->get_const_values()), permuted->get_row_ptrs(), + permuted->get_col_idxs(), as_deivice_type(permuted->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -2470,9 +2495,9 @@ void inv_nonsymm_scale_permute(std::shared_ptr exec, inv_nonsymm_scale_permute_kernel( copy_num_blocks, default_block_size, 0, exec->get_queue(), num_rows, row_scale, row_perm, col_scale, col_perm, orig->get_const_row_ptrs(), - orig->get_const_col_idxs(), orig->get_const_values(), + orig->get_const_col_idxs(), as_deivice_type(orig->get_const_values()), permuted->get_row_ptrs(), permuted->get_col_idxs(), - permuted->get_values()); + as_deivice_type(permuted->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -2497,8 +2522,9 @@ void row_scale_permute(std::shared_ptr exec, row_scale_permute_kernel( copy_num_blocks, default_block_size, 0, exec->get_queue(), num_rows, scale, perm, orig->get_const_row_ptrs(), orig->get_const_col_idxs(), - orig->get_const_values(), row_permuted->get_row_ptrs(), - row_permuted->get_col_idxs(), row_permuted->get_values()); + as_deivice_type(orig->get_const_values()), row_permuted->get_row_ptrs(), + row_permuted->get_col_idxs(), + as_deivice_type(row_permuted->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -2523,8 +2549,9 @@ void inv_row_scale_permute(std::shared_ptr exec, inv_row_scale_permute_kernel( copy_num_blocks, default_block_size, 0, exec->get_queue(), num_rows, scale, perm, orig->get_const_row_ptrs(), orig->get_const_col_idxs(), - orig->get_const_values(), row_permuted->get_row_ptrs(), - row_permuted->get_col_idxs(), row_permuted->get_values()); + as_deivice_type(orig->get_const_values()), row_permuted->get_row_ptrs(), + row_permuted->get_col_idxs(), + as_deivice_type(row_permuted->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -2538,7 +2565,7 @@ void sort_by_column_index(std::shared_ptr exec, const auto num_rows = to_sort->get_size()[0]; const auto row_ptrs = to_sort->get_const_row_ptrs(); auto cols = to_sort->get_col_idxs(); - auto vals = to_sort->get_values(); + auto vals = as_deivice_type(to_sort->get_values()); exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for(sycl::range<1>{num_rows}, [=](sycl::id<1> idx) { const auto row = static_cast(idx[0]); @@ -2631,10 +2658,10 @@ void extract_diagonal(std::shared_ptr exec, const auto num_blocks = ceildiv(config::warp_size * diag_size, default_block_size); - const auto orig_values = orig->get_const_values(); + const auto orig_values = as_device_type(orig->get_const_values()); const auto orig_row_ptrs = orig->get_const_row_ptrs(); const auto orig_col_idxs = orig->get_const_col_idxs(); - auto diag_values = diag->get_values(); + auto diag_values = as_device_type(diag->get_values()); kernel::extract_diagonal(num_blocks, default_block_size, 0, exec->get_queue(), diag_size, nnz, orig_values, @@ -2683,9 +2710,10 @@ void add_scaled_identity(std::shared_ptr exec, const auto nblocks = ceildiv(nthreads, default_block_size); kernel::add_scaled_identity( nblocks, default_block_size, 0, exec->get_queue(), - alpha->get_const_values(), beta->get_const_values(), + as_deivice_type(alpha->get_const_values()), + as_deivice_type(beta->get_const_values()), static_cast(nrows), mtx->get_const_row_ptrs(), - mtx->get_const_col_idxs(), mtx->get_values()); + mtx->get_const_col_idxs(), as_deivice_type(mtx->get_values())); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/dpcpp/matrix/dense_kernels.dp.cpp b/dpcpp/matrix/dense_kernels.dp.cpp index 04f3229eaed..2edab5f2b57 100644 --- a/dpcpp/matrix/dense_kernels.dp.cpp +++ b/dpcpp/matrix/dense_kernels.dp.cpp @@ -22,6 +22,7 @@ #include "dpcpp/base/dim3.dp.hpp" #include "dpcpp/base/helper.hpp" #include "dpcpp/base/onemkl_bindings.hpp" +#include "dpcpp/base/types.hpp" #include "dpcpp/components/cooperative_groups.dp.hpp" #include "dpcpp/components/reduction.dp.hpp" #include "dpcpp/components/thread_ids.dp.hpp" @@ -103,9 +104,9 @@ void transpose(sycl::queue* queue, const matrix::Dense* orig, uninitialized_array, 0> space_acc_ct1(cgh); // Can not pass the member to device function directly - auto in = orig->get_const_values(); + auto in = as_device_type(orig->get_const_values()); auto in_stride = orig->get_stride(); - auto out = trans->get_values(); + auto out = as_device_type(trans->get_values()); auto out_stride = trans->get_stride(); cgh.parallel_for( sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) { @@ -222,9 +223,10 @@ void simple_apply(std::shared_ptr exec, oneapi::mkl::blas::row_major::gemm( *exec->get_queue(), transpose::nontrans, transpose::nontrans, c->get_size()[0], c->get_size()[1], a->get_size()[1], - one(), a->get_const_values(), a->get_stride(), - b->get_const_values(), b->get_stride(), zero(), - c->get_values(), c->get_stride()); + one(), as_device_type(a->get_const_values()), + a->get_stride(), as_device_type(b->get_const_values()), + b->get_stride(), zero(), + as_device_type(c->get_values()), c->get_stride()); } else { dense::fill(exec, c, zero()); } @@ -247,10 +249,10 @@ void apply(std::shared_ptr exec, *exec->get_queue(), transpose::nontrans, transpose::nontrans, c->get_size()[0], c->get_size()[1], a->get_size()[1], exec->copy_val_to_host(alpha->get_const_values()), - a->get_const_values(), a->get_stride(), b->get_const_values(), - b->get_stride(), + as_device_type(a->get_const_values()), a->get_stride(), + as_device_type(b->get_const_values()), b->get_stride(), exec->copy_val_to_host(beta->get_const_values()), - c->get_values(), c->get_stride()); + as_device_type(c->get_values()), c->get_stride()); } else { dense::scale(exec, beta, c); } @@ -268,12 +270,12 @@ void convert_to_coo(std::shared_ptr exec, { const auto num_rows = result->get_size()[0]; const auto num_cols = result->get_size()[1]; - const auto in_vals = source->get_const_values(); + const auto in_vals = as_device_type(source->get_const_values()); const auto stride = source->get_stride(); auto rows = result->get_row_idxs(); auto cols = result->get_col_idxs(); - auto vals = result->get_values(); + auto vals = as_device_type(result->get_values()); exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for(num_rows, [=](sycl::item<1> item) { @@ -303,12 +305,12 @@ void convert_to_csr(std::shared_ptr exec, { const auto num_rows = result->get_size()[0]; const auto num_cols = result->get_size()[1]; - const auto in_vals = source->get_const_values(); + const auto in_vals = as_device_type(source->get_const_values()); const auto stride = source->get_stride(); const auto row_ptrs = result->get_const_row_ptrs(); auto cols = result->get_col_idxs(); - auto vals = result->get_values(); + auto vals = as_device_type(result->get_values()); exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for(num_rows, [=](sycl::item<1> item) { @@ -338,11 +340,11 @@ void convert_to_ell(std::shared_ptr exec, const auto num_rows = result->get_size()[0]; const auto num_cols = result->get_size()[1]; const auto max_nnz_per_row = result->get_num_stored_elements_per_row(); - const auto in_vals = source->get_const_values(); + const auto in_vals = as_device_type(source->get_const_values()); const auto in_stride = source->get_stride(); auto cols = result->get_col_idxs(); - auto vals = result->get_values(); + auto vals = as_device_type(result->get_values()); const auto stride = result->get_stride(); exec->get_queue()->submit([&](sycl::handler& cgh) { @@ -398,7 +400,7 @@ void convert_to_hybrid(std::shared_ptr exec, const auto num_rows = result->get_size()[0]; const auto num_cols = result->get_size()[1]; const auto ell_lim = result->get_ell_num_stored_elements_per_row(); - const auto in_vals = source->get_const_values(); + const auto in_vals = as_device_type(source->get_const_values()); const auto in_stride = source->get_stride(); const auto ell_stride = result->get_ell_stride(); auto ell_cols = result->get_ell_col_idxs(); @@ -453,11 +455,11 @@ void convert_to_sellp(std::shared_ptr exec, const auto num_rows = result->get_size()[0]; const auto num_cols = result->get_size()[1]; const auto stride = source->get_stride(); - const auto in_vals = source->get_const_values(); + const auto in_vals = as_device_type(source->get_const_values()); const auto slice_sets = result->get_const_slice_sets(); const auto slice_size = result->get_slice_size(); - auto vals = result->get_values(); + auto vals = as_device_type(result->get_values()); auto col_idxs = result->get_col_idxs(); exec->get_queue()->submit([&](sycl::handler& cgh) { @@ -495,7 +497,7 @@ void convert_to_sparsity_csr(std::shared_ptr exec, { const auto num_rows = result->get_size()[0]; const auto num_cols = result->get_size()[1]; - const auto in_vals = source->get_const_values(); + const auto in_vals = as_device_type(source->get_const_values()); const auto stride = source->get_stride(); const auto row_ptrs = result->get_const_row_ptrs(); @@ -560,9 +562,10 @@ void conj_transpose(std::shared_ptr exec, const auto sg_size = DCFG_1D::decode<1>(cfg); dim3 grid(ceildiv(size[1], sg_size), ceildiv(size[0], sg_size)); dim3 block(sg_size, sg_size); - kernel::conj_transpose_call(cfg, grid, block, 0, queue, size[0], size[1], - orig->get_const_values(), orig->get_stride(), - trans->get_values(), trans->get_stride()); + kernel::conj_transpose_call( + cfg, grid, block, 0, queue, size[0], size[1], + as_device_type(orig->get_const_values()), orig->get_stride(), + as_device_type(trans->get_values()), trans->get_stride()); } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); diff --git a/dpcpp/matrix/diagonal_kernels.dp.cpp b/dpcpp/matrix/diagonal_kernels.dp.cpp index 2b63138abbe..41931bab4f0 100644 --- a/dpcpp/matrix/diagonal_kernels.dp.cpp +++ b/dpcpp/matrix/diagonal_kernels.dp.cpp @@ -12,6 +12,7 @@ #include "dpcpp/base/config.hpp" #include "dpcpp/base/dim3.dp.hpp" #include "dpcpp/base/helper.hpp" +#include "dpcpp/base/types.hpp" #include "dpcpp/components/cooperative_groups.dp.hpp" #include "dpcpp/components/thread_ids.dp.hpp" @@ -70,9 +71,9 @@ void apply_to_csr(std::shared_ptr exec, matrix::Csr* c, bool inverse) { const auto num_rows = b->get_size()[0]; - const auto diag_values = a->get_const_values(); + const auto diag_values = as_device_type(a->get_const_values()); c->copy_from(b); - auto csr_values = c->get_values(); + auto csr_values = as_device_type(c->get_values()); const auto csr_row_ptrs = c->get_const_row_ptrs(); const auto grid_dim = diff --git a/dpcpp/matrix/ell_kernels.dp.cpp b/dpcpp/matrix/ell_kernels.dp.cpp index a97cb602d52..a8f8e2e0c17 100644 --- a/dpcpp/matrix/ell_kernels.dp.cpp +++ b/dpcpp/matrix/ell_kernels.dp.cpp @@ -15,6 +15,7 @@ #include #include "accessor/reduced_row_major.hpp" +#include "accessor/sycl_helper.hpp" #include "core/base/mixed_precision_types.hpp" #include "core/components/fill_array_kernels.hpp" #include "core/components/prefix_sum_kernels.hpp" @@ -23,6 +24,7 @@ #include "dpcpp/base/config.hpp" #include "dpcpp/base/dim3.dp.hpp" #include "dpcpp/base/helper.hpp" +#include "dpcpp/base/types.hpp" #include "dpcpp/components/atomic.dp.hpp" #include "dpcpp/components/cooperative_groups.dp.hpp" #include "dpcpp/components/format_conversion.dp.hpp" @@ -323,17 +325,20 @@ void abstract_spmv(syn::value_list, if (alpha == nullptr && beta == nullptr) { kernel::spmv( grid_size, block_size, 0, exec->get_queue(), nrows, - num_worker_per_row, a_vals, a->get_const_col_idxs(), stride, - num_stored_elements_per_row, b_vals, c->get_values(), + num_worker_per_row, acc::as_device_range(a_vals), + a->get_const_col_idxs(), stride, num_stored_elements_per_row, + acc::as_device_range(b_vals), as_device_type(c->get_values()), c->get_stride()); } else if (alpha != nullptr && beta != nullptr) { const auto alpha_val = gko::acc::range( std::array{1}, alpha->get_const_values()); kernel::spmv( grid_size, block_size, 0, exec->get_queue(), nrows, - num_worker_per_row, alpha_val, a_vals, a->get_const_col_idxs(), - stride, num_stored_elements_per_row, b_vals, - beta->get_const_values(), c->get_values(), c->get_stride()); + num_worker_per_row, acc::as_device_range(alpha_val), + acc::as_device_range(a_vals), a->get_const_col_idxs(), stride, + num_stored_elements_per_row, acc::as_device_range(b_vals), + as_device_type(beta->get_const_values()), + as_device_type(c->get_values()), c->get_stride()); } else { GKO_KERNEL_NOT_FOUND; } diff --git a/dpcpp/matrix/sparsity_csr_kernels.dp.cpp b/dpcpp/matrix/sparsity_csr_kernels.dp.cpp index 66c57ac5b35..ed3e457e81e 100644 --- a/dpcpp/matrix/sparsity_csr_kernels.dp.cpp +++ b/dpcpp/matrix/sparsity_csr_kernels.dp.cpp @@ -206,14 +206,17 @@ void classical_spmv(syn::value_list, if (alpha == nullptr && beta == nullptr) { kernel::abstract_classical_spmv( grid, block, 0, exec->get_queue(), a->get_size()[0], - a->get_const_value(), a->get_const_col_idxs(), - a->get_const_row_ptrs(), b_vals, c_vals); + as_device_type(a->get_const_value()), a->get_const_col_idxs(), + a->get_const_row_ptrs(), acc::as_device_range(b_vals), + acc::as_device_range(c_vals)); } else if (alpha != nullptr && beta != nullptr) { kernel::abstract_classical_spmv( grid, block, 0, exec->get_queue(), a->get_size()[0], - alpha->get_const_values(), a->get_const_value(), - a->get_const_col_idxs(), a->get_const_row_ptrs(), b_vals, - beta->get_const_values(), c_vals); + as_device_type(alpha->get_const_values()), a->get_const_value(), + a->get_const_col_idxs(), a->get_const_row_ptrs(), + acc::as_device_range(b_vals), + as_device_type(beta->get_const_values()), + acc::as_device_range(c_vals)); } else { GKO_KERNEL_NOT_FOUND; } From 0a155d028895a0cfb894f76b8f6d7606718d4034 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 25 Oct 2024 13:31:14 +0200 Subject: [PATCH 6/7] factorization sycl type --- dpcpp/factorization/factorization_kernels.dp.cpp | 2 +- dpcpp/factorization/par_ic_kernels.dp.cpp | 14 +++++++------- dpcpp/factorization/par_ict_kernels.dp.cpp | 13 +++++++------ .../par_ilut_approx_filter_kernel.dp.cpp | 4 ++-- dpcpp/factorization/par_ilut_filter_kernel.dp.cpp | 2 +- dpcpp/factorization/par_ilut_select_kernel.dp.cpp | 2 +- dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp | 8 ++++---- dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp | 9 +++++---- 8 files changed, 28 insertions(+), 26 deletions(-) diff --git a/dpcpp/factorization/factorization_kernels.dp.cpp b/dpcpp/factorization/factorization_kernels.dp.cpp index 1d9912b4f12..04bd49c2c9a 100644 --- a/dpcpp/factorization/factorization_kernels.dp.cpp +++ b/dpcpp/factorization/factorization_kernels.dp.cpp @@ -496,7 +496,7 @@ void add_diagonal_elements(std::shared_ptr exec, array needs_change_device{exec, 1}; needs_change_device = needs_change_host; - auto dpcpp_old_values = mtx->get_const_values(); + auto dpcpp_old_values = as_device_type(mtx->get_const_values()); auto dpcpp_old_col_idxs = mtx->get_const_col_idxs(); auto dpcpp_old_row_ptrs = mtx->get_row_ptrs(); auto dpcpp_row_ptrs_add = row_ptrs_addition.get_data(); diff --git a/dpcpp/factorization/par_ic_kernels.dp.cpp b/dpcpp/factorization/par_ic_kernels.dp.cpp index 5428460fac5..0ae155a4f82 100644 --- a/dpcpp/factorization/par_ic_kernels.dp.cpp +++ b/dpcpp/factorization/par_ic_kernels.dp.cpp @@ -125,7 +125,7 @@ void init_factor(std::shared_ptr exec, auto num_rows = l->get_size()[0]; auto num_blocks = ceildiv(num_rows, default_block_size); auto l_row_ptrs = l->get_const_row_ptrs(); - auto l_vals = l->get_values(); + auto l_vals = as_device_type(l->get_values()); kernel::ic_init(num_blocks, default_block_size, 0, exec->get_queue(), l_row_ptrs, l_vals, num_rows); } @@ -143,12 +143,12 @@ void compute_factor(std::shared_ptr exec, auto nnz = l->get_num_stored_elements(); auto num_blocks = ceildiv(nnz, default_block_size); for (size_type i = 0; i < iterations; ++i) { - kernel::ic_sweep(num_blocks, default_block_size, 0, exec->get_queue(), - a_lower->get_const_row_idxs(), - a_lower->get_const_col_idxs(), - a_lower->get_const_values(), l->get_const_row_ptrs(), - l->get_const_col_idxs(), l->get_values(), - static_cast(l->get_num_stored_elements())); + kernel::ic_sweep( + num_blocks, default_block_size, 0, exec->get_queue(), + a_lower->get_const_row_idxs(), a_lower->get_const_col_idxs(), + a_lower->get_const_values(), l->get_const_row_ptrs(), + l->get_const_col_idxs(), as_device_type(l->get_values()), + static_cast(l->get_num_stored_elements())); } } diff --git a/dpcpp/factorization/par_ict_kernels.dp.cpp b/dpcpp/factorization/par_ict_kernels.dp.cpp index fb99b662dec..4f11bf7b7b1 100644 --- a/dpcpp/factorization/par_ict_kernels.dp.cpp +++ b/dpcpp/factorization/par_ict_kernels.dp.cpp @@ -402,13 +402,13 @@ void add_candidates(syn::value_list, matrix::CsrBuilder l_new_builder(l_new); auto llh_row_ptrs = llh->get_const_row_ptrs(); auto llh_col_idxs = llh->get_const_col_idxs(); - auto llh_vals = llh->get_const_values(); + auto llh_vals = as_device_type(llh->get_const_values()); auto a_row_ptrs = a->get_const_row_ptrs(); auto a_col_idxs = a->get_const_col_idxs(); - auto a_vals = a->get_const_values(); + auto a_vals = as_device_type(a->get_const_values()); auto l_row_ptrs = l->get_const_row_ptrs(); auto l_col_idxs = l->get_const_col_idxs(); - auto l_vals = l->get_const_values(); + auto l_vals = as_device_type(l->get_const_values()); auto l_new_row_ptrs = l_new->get_row_ptrs(); // count non-zeros per row kernel::ict_tri_spgeam_nnz( @@ -450,9 +450,10 @@ void compute_factor(syn::value_list, auto num_blocks = ceildiv(total_nnz, block_size); kernel::ict_sweep( num_blocks, default_block_size, 0, exec->get_queue(), - a->get_const_row_ptrs(), a->get_const_col_idxs(), a->get_const_values(), - l->get_const_row_ptrs(), l_coo->get_const_row_idxs(), - l->get_const_col_idxs(), l->get_values(), + a->get_const_row_ptrs(), a->get_const_col_idxs(), + as_device_type(a->get_const_values()), l->get_const_row_ptrs(), + l_coo->get_const_row_idxs(), l->get_const_col_idxs(), + as_device_type(l->get_values()), static_cast(l->get_num_stored_elements())); } diff --git a/dpcpp/factorization/par_ilut_approx_filter_kernel.dp.cpp b/dpcpp/factorization/par_ilut_approx_filter_kernel.dp.cpp index 776ffba3fb1..c808f7e0ae8 100644 --- a/dpcpp/factorization/par_ilut_approx_filter_kernel.dp.cpp +++ b/dpcpp/factorization/par_ilut_approx_filter_kernel.dp.cpp @@ -58,7 +58,7 @@ void threshold_filter_approx(syn::value_list, matrix::Csr* m_out, matrix::Coo* m_out_coo) { - auto values = m->get_const_values(); + auto values = as_device_type(m->get_const_values()); IndexType size = m->get_num_stored_elements(); using AbsType = remove_complex; constexpr auto bucket_count = kernel::searchtree_width; @@ -102,7 +102,7 @@ void threshold_filter_approx(syn::value_list, // filter the elements auto old_row_ptrs = m->get_const_row_ptrs(); auto old_col_idxs = m->get_const_col_idxs(); - auto old_vals = m->get_const_values(); + auto old_vals = as_device_type(m->get_const_values()); // compute nnz for each row auto num_rows = static_cast(m->get_size()[0]); auto block_size = default_block_size / subgroup_size; diff --git a/dpcpp/factorization/par_ilut_filter_kernel.dp.cpp b/dpcpp/factorization/par_ilut_filter_kernel.dp.cpp index 5ce9df8a0a9..732a8dc6135 100644 --- a/dpcpp/factorization/par_ilut_filter_kernel.dp.cpp +++ b/dpcpp/factorization/par_ilut_filter_kernel.dp.cpp @@ -57,7 +57,7 @@ void threshold_filter(syn::value_list, { auto old_row_ptrs = a->get_const_row_ptrs(); auto old_col_idxs = a->get_const_col_idxs(); - auto old_vals = a->get_const_values(); + auto old_vals = as_device_type(a->get_const_values()); // compute nnz for each row auto num_rows = static_cast(a->get_size()[0]); auto block_size = default_block_size / subgroup_size; diff --git a/dpcpp/factorization/par_ilut_select_kernel.dp.cpp b/dpcpp/factorization/par_ilut_select_kernel.dp.cpp index 589f8267f21..43c13fc730b 100644 --- a/dpcpp/factorization/par_ilut_select_kernel.dp.cpp +++ b/dpcpp/factorization/par_ilut_select_kernel.dp.cpp @@ -61,7 +61,7 @@ void threshold_select(std::shared_ptr exec, array>& tmp2, remove_complex& threshold) { - auto values = m->get_const_values(); + auto values = as_device_type(m->get_const_values()); IndexType size = m->get_num_stored_elements(); using AbsType = remove_complex; constexpr auto bucket_count = kernel::searchtree_width; diff --git a/dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp b/dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp index 246228763bf..f9643fbe66b 100644 --- a/dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp +++ b/dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp @@ -356,16 +356,16 @@ void add_candidates(syn::value_list, matrix::CsrBuilder u_new_builder(u_new); auto lu_row_ptrs = lu->get_const_row_ptrs(); auto lu_col_idxs = lu->get_const_col_idxs(); - auto lu_vals = lu->get_const_values(); + auto lu_vals = as_device_type(lu->get_const_values()); auto a_row_ptrs = a->get_const_row_ptrs(); auto a_col_idxs = a->get_const_col_idxs(); - auto a_vals = a->get_const_values(); + auto a_vals = as_device_type(a->get_const_values()); auto l_row_ptrs = l->get_const_row_ptrs(); auto l_col_idxs = l->get_const_col_idxs(); - auto l_vals = l->get_const_values(); + auto l_vals = as_device_type(l->get_const_values()); auto u_row_ptrs = u->get_const_row_ptrs(); auto u_col_idxs = u->get_const_col_idxs(); - auto u_vals = u->get_const_values(); + auto u_vals = as_device_type(u->get_const_values()); auto l_new_row_ptrs = l_new->get_row_ptrs(); auto u_new_row_ptrs = u_new->get_row_ptrs(); // count non-zeros per row diff --git a/dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp b/dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp index 601e5dc12d3..4644bb155d2 100644 --- a/dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp +++ b/dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp @@ -176,12 +176,13 @@ void compute_l_u_factors(syn::value_list, auto num_blocks = ceildiv(total_nnz, block_size); kernel::sweep( num_blocks, default_block_size, 0, exec->get_queue(), - a->get_const_row_ptrs(), a->get_const_col_idxs(), a->get_const_values(), - l->get_const_row_ptrs(), l_coo->get_const_row_idxs(), - l->get_const_col_idxs(), l->get_values(), + a->get_const_row_ptrs(), a->get_const_col_idxs(), + as_device_type(a->get_const_values()), l->get_const_row_ptrs(), + l_coo->get_const_row_idxs(), l->get_const_col_idxs(), + as_device_type(l->get_values()), static_cast(l->get_num_stored_elements()), u_coo->get_const_row_idxs(), u_coo->get_const_col_idxs(), - u->get_values(), u_csc->get_const_row_ptrs(), + as_device_type(u->get_values()), u_csc->get_const_row_ptrs(), u_csc->get_const_col_idxs(), u_csc->get_values(), static_cast(u->get_num_stored_elements())); } From e9b24cb5f48fa30833dd527f787ca510a95cc7b0 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 25 Oct 2024 13:31:23 +0200 Subject: [PATCH 7/7] solver/preconditioner/stop sycl type --- dpcpp/preconditioner/isai_kernels.dp.cpp | 32 ++++---- .../jacobi_advanced_apply_kernel.dp.cpp | 7 +- .../jacobi_generate_instantiate.inc.dp.cpp | 9 +- .../jacobi_simple_apply_kernel.dp.cpp | 4 +- dpcpp/solver/cb_gmres_kernels.dp.cpp | 24 +++--- dpcpp/solver/idr_kernels.dp.cpp | 82 +++++++++++-------- dpcpp/stop/residual_norm_kernels.dp.cpp | 4 +- 7 files changed, 91 insertions(+), 71 deletions(-) diff --git a/dpcpp/preconditioner/isai_kernels.dp.cpp b/dpcpp/preconditioner/isai_kernels.dp.cpp index 4082035ff9f..24e780f0619 100644 --- a/dpcpp/preconditioner/isai_kernels.dp.cpp +++ b/dpcpp/preconditioner/isai_kernels.dp.cpp @@ -626,16 +626,20 @@ void generate_tri_inverse(std::shared_ptr exec, kernel::generate_l_inverse( grid, block, 0, exec->get_queue(), static_cast(num_rows), input->get_const_row_ptrs(), - input->get_const_col_idxs(), input->get_const_values(), + input->get_const_col_idxs(), + as_device_type(input->get_const_values()), inverse->get_row_ptrs(), inverse->get_col_idxs(), - inverse->get_values(), excess_rhs_ptrs, excess_nz_ptrs); + as_device_type(inverse->get_values()), excess_rhs_ptrs, + excess_nz_ptrs); } else { kernel::generate_u_inverse( grid, block, 0, exec->get_queue(), static_cast(num_rows), input->get_const_row_ptrs(), - input->get_const_col_idxs(), input->get_const_values(), + input->get_const_col_idxs(), + as_device_type(input->get_const_values()), inverse->get_row_ptrs(), inverse->get_col_idxs(), - inverse->get_values(), excess_rhs_ptrs, excess_nz_ptrs); + as_device_type(inverse->get_values()), excess_rhs_ptrs, + excess_nz_ptrs); } } components::prefix_sum_nonnegative(exec, excess_rhs_ptrs, num_rows + 1); @@ -661,9 +665,9 @@ void generate_general_inverse(std::shared_ptr exec, kernel::generate_general_inverse( grid, block, 0, exec->get_queue(), static_cast(num_rows), input->get_const_row_ptrs(), input->get_const_col_idxs(), - input->get_const_values(), inverse->get_row_ptrs(), - inverse->get_col_idxs(), inverse->get_values(), excess_rhs_ptrs, - excess_nz_ptrs, spd); + as_device_type(input->get_const_values()), inverse->get_row_ptrs(), + inverse->get_col_idxs(), as_device_type(inverse->get_values()), + excess_rhs_ptrs, excess_nz_ptrs, spd); } components::prefix_sum_nonnegative(exec, excess_rhs_ptrs, num_rows + 1); components::prefix_sum_nonnegative(exec, excess_nz_ptrs, num_rows + 1); @@ -691,11 +695,11 @@ void generate_excess_system(std::shared_ptr exec, kernel::generate_excess_system( grid, block, 0, exec->get_queue(), static_cast(num_rows), input->get_const_row_ptrs(), input->get_const_col_idxs(), - input->get_const_values(), inverse->get_const_row_ptrs(), - inverse->get_const_col_idxs(), excess_rhs_ptrs, excess_nz_ptrs, - excess_system->get_row_ptrs(), excess_system->get_col_idxs(), - excess_system->get_values(), excess_rhs->get_values(), e_start, - e_end); + as_device_type(input->get_const_values()), + inverse->get_const_row_ptrs(), inverse->get_const_col_idxs(), + excess_rhs_ptrs, excess_nz_ptrs, excess_system->get_row_ptrs(), + excess_system->get_col_idxs(), excess_system->get_values(), + excess_rhs->get_values(), e_start, e_end); } } @@ -737,8 +741,8 @@ void scatter_excess_solution(std::shared_ptr exec, kernel::copy_excess_solution( grid, block, 0, exec->get_queue(), static_cast(num_rows), inverse->get_const_row_ptrs(), excess_rhs_ptrs, - excess_solution->get_const_values(), inverse->get_values(), e_start, - e_end); + excess_solution->get_const_values(), + as_device_type(inverse->get_values()), e_start, e_end); } } diff --git a/dpcpp/preconditioner/jacobi_advanced_apply_kernel.dp.cpp b/dpcpp/preconditioner/jacobi_advanced_apply_kernel.dp.cpp index 0e26989808e..72eb5aeefc6 100644 --- a/dpcpp/preconditioner/jacobi_advanced_apply_kernel.dp.cpp +++ b/dpcpp/preconditioner/jacobi_advanced_apply_kernel.dp.cpp @@ -59,9 +59,10 @@ void apply(std::shared_ptr exec, size_type num_blocks, syn::value_list(), syn::type_list<>(), exec, num_blocks, block_precisions.get_const_data(), block_pointers.get_const_data(), - blocks.get_const_data(), storage_scheme, alpha->get_const_values(), - b->get_const_values() + col, b->get_stride(), x->get_values() + col, - x->get_stride()); + blocks.get_const_data(), storage_scheme, + as_device_type(alpha->get_const_values()), + as_device_type(b->get_const_values()) + col, b->get_stride(), + as_device_type(x->get_values()) + col, x->get_stride()); } } diff --git a/dpcpp/preconditioner/jacobi_generate_instantiate.inc.dp.cpp b/dpcpp/preconditioner/jacobi_generate_instantiate.inc.dp.cpp index d957ea2c5be..d32bef36974 100644 --- a/dpcpp/preconditioner/jacobi_generate_instantiate.inc.dp.cpp +++ b/dpcpp/preconditioner/jacobi_generate_instantiate.inc.dp.cpp @@ -365,14 +365,15 @@ void generate(syn::value_list, warps_per_block>( grid_size, block_size, 0, exec->get_queue(), mtx->get_size()[0], mtx->get_const_row_ptrs(), mtx->get_const_col_idxs(), - mtx->get_const_values(), accuracy, block_data, storage_scheme, - conditioning, block_precisions, block_ptrs, num_blocks); + as_device_type(mtx->get_const_values()), accuracy, block_data, + storage_scheme, conditioning, block_precisions, block_ptrs, + num_blocks); } else { kernel::generate( grid_size, block_size, 0, exec->get_queue(), mtx->get_size()[0], mtx->get_const_row_ptrs(), mtx->get_const_col_idxs(), - mtx->get_const_values(), block_data, storage_scheme, block_ptrs, - num_blocks); + as_device_type(mtx->get_const_values()), block_data, storage_scheme, + block_ptrs, num_blocks); } } diff --git a/dpcpp/preconditioner/jacobi_simple_apply_kernel.dp.cpp b/dpcpp/preconditioner/jacobi_simple_apply_kernel.dp.cpp index 25701c6dc55..a0040caadc4 100644 --- a/dpcpp/preconditioner/jacobi_simple_apply_kernel.dp.cpp +++ b/dpcpp/preconditioner/jacobi_simple_apply_kernel.dp.cpp @@ -56,8 +56,8 @@ void simple_apply( syn::type_list<>(), exec, num_blocks, block_precisions.get_const_data(), block_pointers.get_const_data(), blocks.get_const_data(), storage_scheme, - b->get_const_values() + col, b->get_stride(), x->get_values() + col, - x->get_stride()); + as_device_type(b->get_const_values()) + col, b->get_stride(), + as_device_type(x->get_values()) + col, x->get_stride()); } } diff --git a/dpcpp/solver/cb_gmres_kernels.dp.cpp b/dpcpp/solver/cb_gmres_kernels.dp.cpp index 7ab010ba29f..d0e4f11bfee 100644 --- a/dpcpp/solver/cb_gmres_kernels.dp.cpp +++ b/dpcpp/solver/cb_gmres_kernels.dp.cpp @@ -939,11 +939,11 @@ void initialize(std::shared_ptr exec, initialize_kernel( grid_dim, block_dim, 0, exec->get_queue(), b->get_size()[0], - b->get_size()[1], krylov_dim, b->get_const_values(), b->get_stride(), - residual->get_values(), residual->get_stride(), - givens_sin->get_values(), givens_sin->get_stride(), - givens_cos->get_values(), givens_cos->get_stride(), - stop_status->get_data()); + b->get_size()[1], krylov_dim, as_device_type(b->get_const_values()), + b->get_stride(), as_device_type(residual->get_values()), + residual->get_stride(), givens_sin->get_values(), + givens_sin->get_stride(), givens_cos->get_values(), + givens_cos->get_stride(), stop_status->get_data()); } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CB_GMRES_INITIALIZE_KERNEL); @@ -990,7 +990,8 @@ void restart(std::shared_ptr exec, const dim3 block_size_nrm(default_dot_dim, default_dot_dim); multinorminf_without_stop_kernel( grid_size_nrm, block_size_nrm, 0, exec->get_queue(), num_rows, - num_rhs, residual->get_const_values(), residual->get_stride(), + num_rhs, as_device_type(residual->get_const_values()), + residual->get_stride(), arnoldi_norm->get_values() + 2 * stride_arnoldi, 0); } @@ -1009,7 +1010,7 @@ void restart(std::shared_ptr exec, 1, 1); restart_2_kernel( grid_dim_2, block_dim, 0, exec->get_queue(), residual->get_size()[0], - residual->get_size()[1], residual->get_const_values(), + residual->get_size()[1], as_device_type(residual->get_const_values()), residual->get_stride(), residual_norm->get_const_values(), residual_norm_collection->get_values(), krylov_bases, next_krylov_basis->get_values(), next_krylov_basis->get_stride(), @@ -1255,9 +1256,10 @@ void solve_upper_triangular( solve_upper_triangular_kernel( grid_dim, block_dim, 0, exec->get_queue(), hessenberg->get_size()[1], num_rhs, residual_norm_collection->get_const_values(), - residual_norm_collection->get_stride(), hessenberg->get_const_values(), - hessenberg->get_stride(), y->get_values(), y->get_stride(), - final_iter_nums->get_const_data()); + residual_norm_collection->get_stride(), + as_device_type(hessenberg->get_const_values()), + hessenberg->get_stride(), as_device_type(y->get_values()), + y->get_stride(), final_iter_nums->get_const_data()); } @@ -1283,7 +1285,7 @@ void calculate_qy(std::shared_ptr exec, calculate_Qy_kernel( grid_dim, block_dim, 0, exec->get_queue(), num_rows, num_cols, - krylov_bases, y->get_const_values(), y->get_stride(), + krylov_bases, as_device_type(y->get_const_values()), y->get_stride(), before_preconditioner->get_values(), stride_before_preconditioner, final_iter_nums->get_const_data()); // Calculate qy diff --git a/dpcpp/solver/idr_kernels.dp.cpp b/dpcpp/solver/idr_kernels.dp.cpp index d59ada362f9..c4bcedbe2bc 100644 --- a/dpcpp/solver/idr_kernels.dp.cpp +++ b/dpcpp/solver/idr_kernels.dp.cpp @@ -582,8 +582,8 @@ void initialize_m(std::shared_ptr exec, const auto grid_dim = ceildiv(m_stride * subspace_dim, default_block_size); initialize_m_kernel(grid_dim, default_block_size, 0, exec->get_queue(), - subspace_dim, nrhs, m->get_values(), m_stride, - stop_status->get_data()); + subspace_dim, nrhs, as_device_type(m->get_values()), + m_stride, stop_status->get_data()); } @@ -638,8 +638,9 @@ void solve_lower_triangular(std::shared_ptr exec, const auto grid_dim = ceildiv(nrhs, default_block_size); solve_lower_triangular_kernel( grid_dim, default_block_size, 0, exec->get_queue(), subspace_dim, nrhs, - m->get_const_values(), m->get_stride(), f->get_const_values(), - f->get_stride(), c->get_values(), c->get_stride(), + as_device_type(m->get_const_values()), m->get_stride(), + as_device_type(f->get_const_values()), f->get_stride(), + as_device_type(c->get_values()), c->get_stride(), stop_status->get_const_data()); } @@ -662,30 +663,34 @@ void update_g_and_u(std::shared_ptr exec, const dim3 block_dim(default_dot_dim, default_dot_dim); for (size_type i = 0; i < k; i++) { - const auto p_i = p->get_const_values() + i * p_stride; + const auto p_i = as_device_type(p->get_const_values()) + i * p_stride; if (nrhs > 1 || is_complex()) { - components::fill_array(exec, alpha->get_values(), nrhs, - zero()); + components::fill_array(exec, as_device_type(alpha->get_values()), + nrhs, zero()); multidot_kernel(grid_dim, block_dim, 0, exec->get_queue(), size, nrhs, p_i, g_k->get_values(), g_k->get_stride(), - alpha->get_values(), stop_status->get_const_data()); + as_device_type(alpha->get_values()), + stop_status->get_const_data()); } else { onemkl::dot(*exec->get_queue(), size, p_i, 1, g_k->get_values(), - g_k->get_stride(), alpha->get_values()); + g_k->get_stride(), as_device_type(alpha->get_values())); } update_g_k_and_u_kernel( ceildiv(size * g_k->get_stride(), default_block_size), default_block_size, 0, exec->get_queue(), k, i, size, nrhs, - alpha->get_const_values(), m->get_const_values(), m->get_stride(), - g->get_const_values(), g->get_stride(), g_k->get_values(), - g_k->get_stride(), u->get_values(), u->get_stride(), + as_device_type(alpha->get_const_values()), + as_device_type(m->get_const_values()), m->get_stride(), + as_device_type(g->get_const_values()), g->get_stride(), + g_k->get_values(), g_k->get_stride(), + as_device_type(u->get_values()), u->get_stride(), stop_status->get_const_data()); } update_g_kernel( ceildiv(size * g_k->get_stride(), default_block_size), default_block_size, 0, exec->get_queue(), k, size, nrhs, - g_k->get_const_values(), g_k->get_stride(), g->get_values(), - g->get_stride(), stop_status->get_const_data()); + g_k->get_const_values(), g_k->get_stride(), + as_device_type(g->get_values()), g->get_stride(), + stop_status->get_const_data()); } @@ -705,8 +710,8 @@ void update_m(std::shared_ptr exec, const size_type nrhs, const dim3 block_dim(default_dot_dim, default_dot_dim); for (size_type i = k; i < subspace_dim; i++) { - const auto p_i = p->get_const_values() + i * p_stride; - auto m_i = m->get_values() + i * m_stride + k * nrhs; + const auto p_i = as_device_type(p->get_const_values()) + i * p_stride; + auto m_i = as_device_type(m->get_values()) + i * m_stride + k * nrhs; if (nrhs > 1 || is_complex()) { components::fill_array(exec, m_i, nrhs, zero()); multidot_kernel(grid_dim, block_dim, 0, exec->get_queue(), size, @@ -735,15 +740,18 @@ void update_x_r_and_f(std::shared_ptr exec, const auto subspace_dim = m->get_size()[0]; const auto grid_dim = ceildiv(size * x->get_stride(), default_block_size); - update_x_r_and_f_kernel(grid_dim, default_block_size, 0, exec->get_queue(), - k, size, subspace_dim, nrhs, m->get_const_values(), - m->get_stride(), g->get_const_values(), - g->get_stride(), u->get_const_values(), - u->get_stride(), f->get_values(), f->get_stride(), - r->get_values(), r->get_stride(), x->get_values(), - x->get_stride(), stop_status->get_const_data()); - components::fill_array(exec, f->get_values() + k * f->get_stride(), nrhs, - zero()); + update_x_r_and_f_kernel( + grid_dim, default_block_size, 0, exec->get_queue(), k, size, + subspace_dim, nrhs, as_device_type(m->get_const_values()), + m->get_stride(), as_device_type(g->get_const_values()), g->get_stride(), + as_device_type(u->get_const_values()), u->get_stride(), + as_device_type(f->get_values()), f->get_stride(), + as_device_type(r->get_values()), r->get_stride(), + as_device_type(x->get_values()), x->get_stride(), + stop_status->get_const_data()); + components::fill_array( + exec, as_device_type(f->get_values()) + k * f->get_stride(), nrhs, + zero()); } @@ -780,11 +788,12 @@ void step_1(std::shared_ptr exec, const size_type nrhs, const auto grid_dim = ceildiv(nrhs * num_rows, default_block_size); step_1_kernel(grid_dim, default_block_size, 0, exec->get_queue(), k, - num_rows, subspace_dim, nrhs, residual->get_const_values(), - residual->get_stride(), c->get_const_values(), - c->get_stride(), g->get_const_values(), g->get_stride(), - v->get_values(), v->get_stride(), - stop_status->get_const_data()); + num_rows, subspace_dim, nrhs, + as_device_type(residual->get_const_values()), + residual->get_stride(), as_device_type(c->get_const_values()), + c->get_stride(), as_device_type(g->get_const_values()), + g->get_stride(), as_device_type(v->get_values()), + v->get_stride(), stop_status->get_const_data()); } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IDR_STEP_1_KERNEL); @@ -805,10 +814,12 @@ void step_2(std::shared_ptr exec, const size_type nrhs, const auto grid_dim = ceildiv(nrhs * num_rows, default_block_size); step_2_kernel(grid_dim, default_block_size, 0, exec->get_queue(), k, - num_rows, subspace_dim, nrhs, omega->get_const_values(), + num_rows, subspace_dim, nrhs, + as_device_type(omega->get_const_values()), preconditioned_vector->get_const_values(), - preconditioned_vector->get_stride(), c->get_const_values(), - c->get_stride(), u->get_values(), u->get_stride(), + preconditioned_vector->get_stride(), + as_device_type(c->get_const_values()), c->get_stride(), + as_device_type(u->get_values()), u->get_stride(), stop_status->get_const_data()); } @@ -841,8 +852,9 @@ void compute_omega( { const auto grid_dim = ceildiv(nrhs, config::warp_size); compute_omega_kernel(grid_dim, config::warp_size, 0, exec->get_queue(), - nrhs, kappa, tht->get_const_values(), - residual_norm->get_const_values(), omega->get_values(), + nrhs, kappa, as_device_type(tht->get_const_values()), + residual_norm->get_const_values(), + as_device_type(omega->get_values()), stop_status->get_const_data()); } diff --git a/dpcpp/stop/residual_norm_kernels.dp.cpp b/dpcpp/stop/residual_norm_kernels.dp.cpp index ddb617a1a84..5de40b362f9 100644 --- a/dpcpp/stop/residual_norm_kernels.dp.cpp +++ b/dpcpp/stop/residual_norm_kernels.dp.cpp @@ -46,7 +46,7 @@ void residual_norm(std::shared_ptr exec, }); auto orig_tau_val = orig_tau->get_const_values(); - auto tau_val = tau->get_const_values(); + auto tau_val = as_device_type(tau->get_const_values()); auto stop_status_val = stop_status->get_data(); exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( @@ -102,7 +102,7 @@ void implicit_residual_norm( }); auto orig_tau_val = orig_tau->get_const_values(); - auto tau_val = tau->get_const_values(); + auto tau_val = as_device_type(tau->get_const_values()); auto stop_status_val = stop_status->get_data(); exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for(