Skip to content

Commit

Permalink
Merge pull request #1730 from IntelPython/dpctl-tensor-nextafter
Browse files Browse the repository at this point in the history
Implements `dpctl.tensor.nextafter` per array API
  • Loading branch information
oleksandr-pavlyk authored Jul 19, 2024
2 parents 28a231e + e5f9810 commit 477946e
Show file tree
Hide file tree
Showing 9 changed files with 605 additions and 0 deletions.
1 change: 1 addition & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ set(_elementwise_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/minimum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/multiply.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/negative.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/nextafter.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/not_equal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/positive.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
Expand Down
2 changes: 2 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
minimum,
multiply,
negative,
nextafter,
not_equal,
positive,
pow,
Expand Down Expand Up @@ -371,4 +372,5 @@
"cumulative_logsumexp",
"cumulative_prod",
"cumulative_sum",
"nextafter",
]
34 changes: 34 additions & 0 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,6 +1487,40 @@
)
del _negative_docstring_

# B28: ==== NEXTAFTER (x1, x2)
_nextafter_docstring_ = r"""
nextafter(x1, x2, /, \*, out=None, order='K')
Calculates the next floating-point value after element `x1_i` of the input
array `x1` toward the respective element `x2_i` of the input array `x2`.
Args:
x1 (usm_ndarray):
First input array.
x2 (usm_ndarray):
Second input array.
out (Union[usm_ndarray, None], optional):
Output array to populate.
Array must have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the new output array, if parameter
`out` is ``None``.
Default: "K".
Returns:
usm_ndarray:
An array containing the element-wise next representable values of `x1`
in the direction of `x2`. The data type of the returned array is
determined by the Type Promotion Rules.
"""
nextafter = BinaryElementwiseFunc(
"nextafter",
ti._nextafter_result_type,
ti._nextafter,
_nextafter_docstring_,
)
del _nextafter_docstring_

# B20: ==== NOT_EQUAL (x1, x2)
_not_equal_docstring_ = r"""
not_equal(x1, x2, /, \*, out=None, order='K')
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
//=== NEXTAFTER.hpp - Binary function NEXTAFTER ------ *-C++-*--/===//
//
// Data Parallel Control (dpctl)
//
// Copyright 2020-2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===---------------------------------------------------------------------===//
///
/// \file
/// This file defines kernels for elementwise evaluation of NEXTAFTER(x1, x2)
/// function.
//===---------------------------------------------------------------------===//

#pragma once
#include <cstddef>
#include <cstdint>
#include <sycl/sycl.hpp>
#include <type_traits>

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch_building.hpp"
#include "utils/type_utils.hpp"

#include "kernels/dpctl_tensor_types.hpp"
#include "kernels/elementwise_functions/common.hpp"

namespace dpctl
{
namespace tensor
{
namespace kernels
{
namespace nextafter
{

namespace td_ns = dpctl::tensor::type_dispatch;
namespace tu_ns = dpctl::tensor::type_utils;

template <typename argT1, typename argT2, typename resT> struct NextafterFunctor
{

using supports_sg_loadstore = std::true_type;
using supports_vec = std::true_type;

resT operator()(const argT1 &in1, const argT2 &in2) const
{
return sycl::nextafter(in1, in2);
}

template <int vec_sz>
sycl::vec<resT, vec_sz>
operator()(const sycl::vec<argT1, vec_sz> &in1,
const sycl::vec<argT2, vec_sz> &in2) const
{
auto res = sycl::nextafter(in1, in2);
if constexpr (std::is_same_v<resT,
typename decltype(res)::element_type>) {
return res;
}
else {
using dpctl::tensor::type_utils::vec_cast;

return vec_cast<resT, typename decltype(res)::element_type, vec_sz>(
res);
}
}
};

template <typename argT1,
typename argT2,
typename resT,
unsigned int vec_sz = 4,
unsigned int n_vecs = 2,
bool enable_sg_loadstore = true>
using NextafterContigFunctor = elementwise_common::BinaryContigFunctor<
argT1,
argT2,
resT,
NextafterFunctor<argT1, argT2, resT>,
vec_sz,
n_vecs,
enable_sg_loadstore>;

template <typename argT1, typename argT2, typename resT, typename IndexerT>
using NextafterStridedFunctor = elementwise_common::BinaryStridedFunctor<
argT1,
argT2,
resT,
IndexerT,
NextafterFunctor<argT1, argT2, resT>>;

template <typename T1, typename T2> struct NextafterOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
td_ns::BinaryTypeMapResultEntry<T1,
sycl::half,
T2,
sycl::half,
sycl::half>,
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
td_ns::DefaultResultEntry<void>>::result_type;
};

template <typename argT1,
typename argT2,
typename resT,
unsigned int vec_sz,
unsigned int n_vecs>
class nextafter_contig_kernel;

template <typename argTy1, typename argTy2>
sycl::event nextafter_contig_impl(sycl::queue &exec_q,
size_t nelems,
const char *arg1_p,
ssize_t arg1_offset,
const char *arg2_p,
ssize_t arg2_offset,
char *res_p,
ssize_t res_offset,
const std::vector<sycl::event> &depends = {})
{
return elementwise_common::binary_contig_impl<
argTy1, argTy2, NextafterOutputType, NextafterContigFunctor,
nextafter_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends);
}

template <typename fnT, typename T1, typename T2> struct NextafterContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename NextafterOutputType<T1, T2>::value_type,
void>)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = nextafter_contig_impl<T1, T2>;
return fn;
}
}
};

template <typename fnT, typename T1, typename T2> struct NextafterTypeMapFactory
{
/*! @brief get typeid for output type of std::nextafter(T1 x, T2 y) */
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
using rT = typename NextafterOutputType<T1, T2>::value_type;
;
return td_ns::GetTypeid<rT>{}.get();
}
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class nextafter_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
nextafter_strided_impl(sycl::queue &exec_q,
size_t nelems,
int nd,
const ssize_t *shape_and_strides,
const char *arg1_p,
ssize_t arg1_offset,
const char *arg2_p,
ssize_t arg2_offset,
char *res_p,
ssize_t res_offset,
const std::vector<sycl::event> &depends,
const std::vector<sycl::event> &additional_depends)
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, NextafterOutputType, NextafterStridedFunctor,
nextafter_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
arg1_offset, arg2_p, arg2_offset, res_p,
res_offset, depends, additional_depends);
}

template <typename fnT, typename T1, typename T2> struct NextafterStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename NextafterOutputType<T1, T2>::value_type,
void>)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = nextafter_strided_impl<T1, T2>;
return fn;
}
}
};

} // namespace nextafter
} // namespace kernels
} // namespace tensor
} // namespace dpctl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
#include "minimum.hpp"
#include "multiply.hpp"
#include "negative.hpp"
#include "nextafter.hpp"
#include "not_equal.hpp"
#include "positive.hpp"
#include "pow.hpp"
Expand Down Expand Up @@ -158,6 +159,7 @@ void init_elementwise_functions(py::module_ m)
init_maximum(m);
init_minimum(m);
init_multiply(m);
init_nextafter(m);
init_negative(m);
init_not_equal(m);
init_positive(m);
Expand Down
Loading

0 comments on commit 477946e

Please sign in to comment.