Skip to content

Commit

Permalink
Merge branch 'master' into nd-support-to-trim_zero
Browse files: browse the repository at this point in the history.
  • Loading branch information
antonwolfy authored Dec 18, 2024
2 parents 303ec8d + ea718e3 commit ee19ab4
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions dpnp/backend/kernels/dpnp_krnl_elemwise.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -41,8 +41,11 @@
using dpctl::tensor::kernels::alignment_utils::is_aligned;
using dpctl::tensor::kernels::alignment_utils::required_alignment;

using sycl::ext::oneapi::experimental::group_load;
using sycl::ext::oneapi::experimental::group_store;
// Short alias for the experimental oneAPI SYCL extension namespace.
namespace syclex = sycl::ext::oneapi::experimental;
// Bring the group-cooperative load/store entry points into scope.
using syclex::group_load;
using syclex::group_store;

// Compile-time property list carrying `data_placement_striped`; it is passed
// as the trailing argument to the group_load/group_store calls in this file.
// NOTE(review): presumably this selects the striped (rather than the default
// blocked) element-to-work-item layout described by the group load/store
// extension -- confirm against the extension specification.
constexpr auto striped = syclex::properties{syclex::data_placement_striped};

template <typename T>
constexpr T dispatch_erf_op(T elem)
Expand Down Expand Up @@ -529,8 +532,8 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
sycl::vec<_DataType_input1, vec_sz> x1{}; \
sycl::vec<_DataType_input2, vec_sz> x2{}; \
\
group_load(sg, input1_multi_ptr, x1); \
group_load(sg, input2_multi_ptr, x2); \
group_load(sg, input1_multi_ptr, x1, striped); \
group_load(sg, input2_multi_ptr, x2, striped); \
\
res_vec = __vec_operation__; \
} \
Expand All @@ -540,8 +543,10 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
sycl::vec<_DataType_input1, vec_sz> tmp_x1{}; \
sycl::vec<_DataType_input2, vec_sz> tmp_x2{}; \
\
group_load(sg, input1_multi_ptr, tmp_x1); \
group_load(sg, input2_multi_ptr, tmp_x2); \
group_load(sg, input1_multi_ptr, tmp_x1, \
striped); \
group_load(sg, input2_multi_ptr, tmp_x2, \
striped); \
\
sycl::vec<_DataType_output, vec_sz> x1 = \
dpnp_vec_cast<_DataType_output, \
Expand All @@ -559,16 +564,16 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
sycl::vec<_DataType_input1, vec_sz> x1{}; \
sycl::vec<_DataType_input2, vec_sz> x2{}; \
\
group_load(sg, input1_multi_ptr, x1); \
group_load(sg, input2_multi_ptr, x2); \
group_load(sg, input1_multi_ptr, x1, striped); \
group_load(sg, input2_multi_ptr, x2, striped); \
\
for (size_t k = 0; k < vec_sz; ++k) { \
const _DataType_output input1_elem = x1[k]; \
const _DataType_output input2_elem = x2[k]; \
res_vec[k] = __operation__; \
} \
} \
group_store(sg, res_vec, result_multi_ptr); \
group_store(sg, res_vec, result_multi_ptr, striped); \
} \
else { \
for (size_t k = start + sg.get_local_id()[0]; \
Expand Down

0 comments on commit ee19ab4

Please sign in to comment.