Skip to content

Commit

Permalink
Merge pull request #1366 from IntelPython/create-multi_ptr-per-sycl-2…
Browse files Browse the repository at this point in the history
…020-standard

Conversion from raw to multi_ptr should be done with address_space_cast
  • Loading branch information
oleksandr-pavlyk committed Aug 25, 2023
2 parents d85e130 + f7eee1e commit 8eab04b
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 145 deletions.
21 changes: 11 additions & 10 deletions dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,25 +244,26 @@ class ContigCopyFunctor

if (base + n_vecs * vec_sz * sgSize < nelems &&
sgSize == max_sgSize) {
using src_ptrT =
sycl::multi_ptr<const srcT,
sycl::access::address_space::global_space>;
using dst_ptrT =
sycl::multi_ptr<dstT,
sycl::access::address_space::global_space>;
sycl::vec<srcT, vec_sz> src_vec;
sycl::vec<dstT, vec_sz> dst_vec;

#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
src_vec =
sg.load<vec_sz>(src_ptrT(&src_p[base + it * sgSize]));
auto src_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(
&src_p[base + it * sgSize]);
auto dst_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(
&dst_p[base + it * sgSize]);

src_vec = sg.load<vec_sz>(src_multi_ptr);
#pragma unroll
for (std::uint8_t k = 0; k < vec_sz; k++) {
dst_vec[k] = fn(src_vec[k]);
}
sg.store<vec_sz>(dst_ptrT(&dst_p[base + it * sgSize]),
dst_vec);
sg.store<vec_sz>(dst_multi_ptr, dst_vec);
}
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ struct UnaryContigFunctor
if constexpr (UnaryOperatorT::is_constant::value) {
// value of operator is known to be a known constant
constexpr resT const_val = UnaryOperatorT::constant_value;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;

auto sg = ndit.get_sub_group();
std::uint8_t sgSize = sg.get_local_range()[0];
Expand All @@ -80,8 +77,11 @@ struct UnaryContigFunctor
sycl::vec<resT, vec_sz> res_vec(const_val);
#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
res_vec);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

sg.store<vec_sz>(out_multi_ptr, res_vec);
}
}
else {
Expand All @@ -94,13 +94,6 @@ struct UnaryContigFunctor
else if constexpr (UnaryOperatorT::supports_sg_loadstore::value &&
UnaryOperatorT::supports_vec::value)
{
using in_ptrT =
sycl::multi_ptr<const argT,
sycl::access::address_space::global_space>;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;

auto sg = ndit.get_sub_group();
std::uint16_t sgSize = sg.get_local_range()[0];
std::uint16_t max_sgSize = sg.get_max_local_range()[0];
Expand All @@ -113,10 +106,16 @@ struct UnaryContigFunctor

#pragma unroll
for (std::uint16_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
x = sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
auto in_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in[base + it * sgSize]);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

x = sg.load<vec_sz>(in_multi_ptr);
sycl::vec<resT, vec_sz> res_vec = op(x);
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
res_vec);
sg.store<vec_sz>(out_multi_ptr, res_vec);
}
}
else {
Expand All @@ -141,23 +140,23 @@ struct UnaryContigFunctor

if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
(maxsgSize == sgSize)) {
using in_ptrT =
sycl::multi_ptr<const argT,
sycl::access::address_space::global_space>;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;
sycl::vec<argT, vec_sz> arg_vec;

#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
arg_vec = sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
auto in_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in[base + it * sgSize]);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

arg_vec = sg.load<vec_sz>(in_multi_ptr);
#pragma unroll
for (std::uint8_t k = 0; k < vec_sz; ++k) {
arg_vec[k] = op(arg_vec[k]);
}
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
arg_vec);
sg.store<vec_sz>(out_multi_ptr, arg_vec);
}
}
else {
Expand All @@ -179,24 +178,24 @@ struct UnaryContigFunctor

if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
(maxsgSize == sgSize)) {
using in_ptrT =
sycl::multi_ptr<const argT,
sycl::access::address_space::global_space>;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;
sycl::vec<argT, vec_sz> arg_vec;
sycl::vec<resT, vec_sz> res_vec;

#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
arg_vec = sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
auto in_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in[base + it * sgSize]);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

arg_vec = sg.load<vec_sz>(in_multi_ptr);
#pragma unroll
for (std::uint8_t k = 0; k < vec_sz; ++k) {
res_vec[k] = op(arg_vec[k]);
}
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
res_vec);
sg.store<vec_sz>(out_multi_ptr, res_vec);
}
}
else {
Expand Down Expand Up @@ -365,28 +364,26 @@ struct BinaryContigFunctor

if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
(sgSize == maxsgSize)) {
using in_ptrT1 =
sycl::multi_ptr<const argT1,
sycl::access::address_space::global_space>;
using in_ptrT2 =
sycl::multi_ptr<const argT2,
sycl::access::address_space::global_space>;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;
sycl::vec<argT1, vec_sz> arg1_vec;
sycl::vec<argT2, vec_sz> arg2_vec;
sycl::vec<resT, vec_sz> res_vec;

#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
arg1_vec =
sg.load<vec_sz>(in_ptrT1(&in1[base + it * sgSize]));
arg2_vec =
sg.load<vec_sz>(in_ptrT2(&in2[base + it * sgSize]));
auto in1_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in1[base + it * sgSize]);
auto in2_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in2[base + it * sgSize]);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

arg1_vec = sg.load<vec_sz>(in1_multi_ptr);
arg2_vec = sg.load<vec_sz>(in2_multi_ptr);
res_vec = op(arg1_vec, arg2_vec);
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
res_vec);
sg.store<vec_sz>(out_multi_ptr, res_vec);
}
}
else {
Expand All @@ -407,32 +404,30 @@ struct BinaryContigFunctor

if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
(sgSize == maxsgSize)) {
using in_ptrT1 =
sycl::multi_ptr<const argT1,
sycl::access::address_space::global_space>;
using in_ptrT2 =
sycl::multi_ptr<const argT2,
sycl::access::address_space::global_space>;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;
sycl::vec<argT1, vec_sz> arg1_vec;
sycl::vec<argT2, vec_sz> arg2_vec;
sycl::vec<resT, vec_sz> res_vec;

#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
arg1_vec =
sg.load<vec_sz>(in_ptrT1(&in1[base + it * sgSize]));
arg2_vec =
sg.load<vec_sz>(in_ptrT2(&in2[base + it * sgSize]));
auto in1_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in1[base + it * sgSize]);
auto in2_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in2[base + it * sgSize]);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

arg1_vec = sg.load<vec_sz>(in1_multi_ptr);
arg2_vec = sg.load<vec_sz>(in2_multi_ptr);
#pragma unroll
for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
res_vec[vec_id] =
op(arg1_vec[vec_id], arg2_vec[vec_id]);
}
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
res_vec);
sg.store<vec_sz>(out_multi_ptr, res_vec);
}
}
else {
Expand Down Expand Up @@ -530,22 +525,24 @@ struct BinaryContigMatrixContigRowBroadcastingFunctor
size_t base = gid - sg.get_local_id()[0];

if (base + sgSize < n_elems) {
using in_ptrT1 =
sycl::multi_ptr<const argT1,
sycl::access::address_space::global_space>;
using in_ptrT2 =
sycl::multi_ptr<const argT2,
sycl::access::address_space::global_space>;
using res_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;

const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1]));
auto in1_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&mat[base]);

auto in2_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&padded_vec[base % n1]);

auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&res[base]);

const argT1 mat_el = sg.load(in1_multi_ptr);
const argT2 vec_el = sg.load(in2_multi_ptr);

resT res_el = op(mat_el, vec_el);

sg.store(res_ptrT(&res[base]), res_el);
sg.store(out_multi_ptr, res_el);
}
else {
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
Expand Down Expand Up @@ -592,22 +589,24 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor
size_t base = gid - sg.get_local_id()[0];

if (base + sgSize < n_elems) {
using in_ptrT1 =
sycl::multi_ptr<const argT1,
sycl::access::address_space::global_space>;
using in_ptrT2 =
sycl::multi_ptr<const argT2,
sycl::access::address_space::global_space>;
using res_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;

const argT2 mat_el = sg.load(in_ptrT2(&mat[base]));
const argT1 vec_el = sg.load(in_ptrT1(&padded_vec[base % n1]));
auto in1_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&padded_vec[base % n1]);

auto in2_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&mat[base]);

auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&res[base]);

const argT2 mat_el = sg.load(in2_multi_ptr);
const argT1 vec_el = sg.load(in1_multi_ptr);

resT res_el = op(vec_el, mat_el);

sg.store(res_ptrT(&res[base]), res_el);
sg.store(out_multi_ptr, res_el);
}
else {
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
Expand Down
Loading

0 comments on commit 8eab04b

Please sign in to comment.