Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dedicated kernels for in-place dpt.divide and dpt.floor_divide #1431

Merged
merged 3 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@
ti._divide_result_type,
ti._divide,
_divide_docstring_,
binary_inplace_fn=ti._divide_inplace,
acceptance_fn=_acceptance_fn_divide,
)

Expand Down Expand Up @@ -720,6 +721,7 @@
ti._floor_divide_result_type,
ti._floor_divide,
_floor_divide_docstring_,
binary_inplace_fn=ti._floor_divide_inplace,
)

# B11: ==== GREATER (x1, x2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ struct FloorDivideFunctor

resT operator()(const argT1 &in1, const argT2 &in2) const
{
if constexpr (std::is_same_v<argT1, bool> &&
std::is_same_v<argT2, bool>) {
return (in2) ? static_cast<resT>(in1) : resT(0);
}
else if constexpr (std::is_integral_v<argT1> ||
std::is_integral_v<argT2>) {
if constexpr (std::is_integral_v<argT1> || std::is_integral_v<argT2>) {
if (in2 == argT2(0)) {
return resT(0);
}
Expand All @@ -87,16 +82,7 @@ struct FloorDivideFunctor
operator()(const sycl::vec<argT1, vec_sz> &in1,
const sycl::vec<argT2, vec_sz> &in2) const
{
if constexpr (std::is_same_v<argT1, bool> &&
std::is_same_v<argT2, bool>) {
sycl::vec<resT, vec_sz> res;
#pragma unroll
for (int i = 0; i < vec_sz; ++i) {
res[i] = (in2[i]) ? static_cast<resT>(in1[i]) : resT(0);
}
return res;
}
else if constexpr (std::is_integral_v<resT>) {
if constexpr (std::is_integral_v<resT>) {
sycl::vec<resT, vec_sz> res;
#pragma unroll
for (int i = 0; i < vec_sz; ++i) {
Expand Down Expand Up @@ -165,7 +151,6 @@ template <typename T1, typename T2> struct FloorDivideOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, std::int8_t>,
td_ns::BinaryTypeMapResultEntry<T1,
std::uint8_t,
T2,
Expand Down Expand Up @@ -315,6 +300,183 @@ struct FloorDivideStridedFactory
}
};

/*!
 * @brief Element-wise functor performing in-place floor division.
 *
 * The first operand is taken by mutable reference and overwritten with
 * floor(in1 / in2); nothing is returned.
 *
 * Integral result types:
 *   - a zero divisor stores 0 in the result instead of dividing;
 *   - signed types divide with the built-in (truncating) operator and then
 *     subtract one whenever the remainder is non-zero and its sign differs
 *     from the sign of the divisor;
 *   - unsigned types use plain division.
 *
 * Non-integral result types divide first and then apply std::floor.
 *
 * NOTE(review): in the signed path, the most negative value divided by -1
 * is not special-cased -- confirm whether that input needs handling.
 *
 * @tparam argT  Element type of the divisor (second operand).
 * @tparam resT  Element type of the dividend, which also receives the result.
 */
template <typename argT, typename resT> struct FloorDivideInplaceFunctor
{
    using supports_sg_loadstore = std::true_type;
    using supports_vec = std::true_type;

    /*! @brief Scalar overload: replaces `in1` with floor(in1 / in2). */
    void operator()(resT &in1, const argT &in2) const
    {
        if constexpr (!std::is_integral_v<resT>) {
            in1 /= in2;
            // An exactly-zero quotient is stored as is; every other value
            // (NaN included, since NaN != 0) goes through std::floor.
            if (in1 != resT(0)) {
                in1 = std::floor(in1);
            }
        }
        else {
            // Integer division by zero is defined here to produce zero.
            if (in2 == argT(0)) {
                in1 = 0;
                return;
            }
            if constexpr (std::is_signed_v<resT>) {
                // Take the remainder of the original dividend before the
                // quotient overwrites it.
                const auto remainder = in1 % in2;
                in1 /= in2;
                // Truncation rounds toward zero; step down by one when the
                // exact quotient was negative and not a whole number.
                const bool round_down =
                    (remainder != 0 &&
                     opposite_signs(remainder < 0, in2 < 0));
                in1 -= round_down;
            }
            else {
                in1 /= in2;
            }
        }
    }

    /*! @brief Vector overload: same operation applied to each lane. */
    template <int vec_sz>
    void operator()(sycl::vec<resT, vec_sz> &in1,
                    const sycl::vec<argT, vec_sz> &in2) const
    {
        if constexpr (!std::is_integral_v<resT>) {
            // Divide all lanes at once, then floor lane by lane, leaving
            // lanes whose divisor was zero untouched.
            in1 /= in2;
#pragma unroll
            for (int lane = 0; lane < vec_sz; ++lane) {
                if (in2[lane] != argT(0)) {
                    in1[lane] = std::floor(in1[lane]);
                }
            }
        }
        else {
#pragma unroll
            for (int lane = 0; lane < vec_sz; ++lane) {
                if (in2[lane] != argT(0)) {
                    if constexpr (std::is_signed_v<resT>) {
                        // Remainder first, while the lane still holds the
                        // original dividend.
                        const auto remainder = in1[lane] % in2[lane];
                        in1[lane] /= in2[lane];
                        const bool round_down =
                            (remainder != 0 &&
                             opposite_signs(remainder < 0, in2[lane] < 0));
                        in1[lane] -= round_down;
                    }
                    else {
                        in1[lane] /= in2[lane];
                    }
                }
                else {
                    // Integer division by zero yields zero for this lane.
                    in1[lane] = 0;
                }
            }
        }
    }

private:
    /*! @brief True when exactly one of the two "is negative" flags is set. */
    bool opposite_signs(bool lhs_negative, bool rhs_negative) const
    {
        return static_cast<bool>(lhs_negative ^ rhs_negative);
    }
};

// Alias that plugs FloorDivideInplaceFunctor<argT, resT> into
// elementwise_common::BinaryInplaceContigFunctor.
// vec_sz and n_vecs are forwarded unchanged to that wrapper and default to
// 4 and 2 respectively when not supplied.
template <typename argT,
typename resT,
unsigned int vec_sz = 4,
unsigned int n_vecs = 2>
using FloorDivideInplaceContigFunctor =
elementwise_common::BinaryInplaceContigFunctor<
argT,
resT,
FloorDivideInplaceFunctor<argT, resT>,
vec_sz,
n_vecs>;

// Alias that plugs FloorDivideInplaceFunctor<argT, resT> into
// elementwise_common::BinaryInplaceStridedFunctor, forwarding IndexerT
// unchanged as that wrapper's indexer type parameter.
template <typename argT, typename resT, typename IndexerT>
using FloorDivideInplaceStridedFunctor =
elementwise_common::BinaryInplaceStridedFunctor<
argT,
resT,
IndexerT,
FloorDivideInplaceFunctor<argT, resT>>;

// Declared but never defined: this class template is only handed to
// binary_inplace_contig_impl as a template argument (presumably to name the
// generated kernel -- confirm against elementwise_common).
template <typename argTy,
          typename resTy,
          unsigned int vec_sz,
          unsigned int n_vecs>
class floor_divide_inplace_contig_kernel;

/*!
 * @brief In-place floor_divide entry point built on
 *        elementwise_common::binary_inplace_contig_impl.
 *
 * Every argument is forwarded unchanged; the call is parameterised with
 * FloorDivideInplaceContigFunctor and floor_divide_inplace_contig_kernel.
 *
 * @tparam argTy      Element type of the right-hand operand.
 * @tparam resTy      Element type of the operand updated in place.
 * @param exec_q      Queue forwarded to the common implementation.
 * @param nelems      Element count forwarded to the common implementation.
 * @param arg_p       Pointer to the right-hand operand data.
 * @param arg_offset  Offset associated with arg_p.
 * @param res_p       Pointer to the data updated in place.
 * @param res_offset  Offset associated with res_p.
 * @param depends     Events forwarded as dependencies (empty by default).
 * @return The event produced by the common implementation.
 */
template <typename argTy, typename resTy>
sycl::event
floor_divide_inplace_contig_impl(sycl::queue &exec_q,
                                 size_t nelems,
                                 const char *arg_p,
                                 py::ssize_t arg_offset,
                                 char *res_p,
                                 py::ssize_t res_offset,
                                 const std::vector<sycl::event> &depends = {})
{
    sycl::event comp_ev = elementwise_common::binary_inplace_contig_impl<
        argTy, resTy, FloorDivideInplaceContigFunctor,
        floor_divide_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset,
                                            res_p, res_offset, depends);
    return comp_ev;
}

template <typename fnT, typename T1, typename T2>
struct FloorDivideInplaceContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename FloorDivideOutputType<T1, T2>::value_type,
void>)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = floor_divide_inplace_contig_impl<T1, T2>;
return fn;
}
}
};

// Declared but never defined: this class template is only handed to
// binary_inplace_strided_impl as a template argument (presumably to name the
// generated kernel -- confirm against elementwise_common).
template <typename argTy, typename resTy, typename IndexerT>
class floor_divide_inplace_strided_kernel;

/*!
 * @brief In-place floor_divide entry point built on
 *        elementwise_common::binary_inplace_strided_impl.
 *
 * Every argument is forwarded unchanged; the call is parameterised with
 * FloorDivideInplaceStridedFunctor and floor_divide_inplace_strided_kernel.
 *
 * @tparam argTy              Element type of the right-hand operand.
 * @tparam resTy              Element type of the operand updated in place.
 * @param exec_q              Queue forwarded to the common implementation.
 * @param nelems              Element count forwarded to the common
 *                            implementation.
 * @param nd                  Dimension count forwarded alongside
 *                            shape_and_strides.
 * @param shape_and_strides   Pointer forwarded as-is; its layout is defined
 *                            by the common implementation.
 * @param arg_p               Pointer to the right-hand operand data.
 * @param arg_offset          Offset associated with arg_p.
 * @param res_p               Pointer to the data updated in place.
 * @param res_offset          Offset associated with res_p.
 * @param depends             Events forwarded as dependencies.
 * @param additional_depends  Further events forwarded as dependencies.
 * @return The event produced by the common implementation.
 */
template <typename argTy, typename resTy>
sycl::event floor_divide_inplace_strided_impl(
    sycl::queue &exec_q,
    size_t nelems,
    int nd,
    const py::ssize_t *shape_and_strides,
    const char *arg_p,
    py::ssize_t arg_offset,
    char *res_p,
    py::ssize_t res_offset,
    const std::vector<sycl::event> &depends,
    const std::vector<sycl::event> &additional_depends)
{
    sycl::event comp_ev = elementwise_common::binary_inplace_strided_impl<
        argTy, resTy, FloorDivideInplaceStridedFunctor,
        floor_divide_inplace_strided_kernel>(
        exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
        res_offset, depends, additional_depends);
    return comp_ev;
}

template <typename fnT, typename T1, typename T2>
struct FloorDivideInplaceStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename FloorDivideOutputType<T1, T2>::value_type,
void>)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = floor_divide_inplace_strided_impl<T1, T2>;
return fn;
}
}
};

} // namespace floor_divide
} // namespace kernels
} // namespace tensor
Expand Down
Loading
Loading