Skip to content

Commit

Permalink
add kernel for take (#542)
Browse files Browse the repository at this point in the history
* add kernel for take
  • Loading branch information
Rubtsowa authored Jan 28, 2021
1 parent 113e6be commit f367dfe
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 6 deletions.
1 change: 1 addition & 0 deletions dpnp/backend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ set(DPNP_SRC
kernels/dpnp_krnl_common.cpp
kernels/dpnp_krnl_elemwise.cpp
kernels/dpnp_krnl_fft.cpp
kernels/dpnp_krnl_indexing.cpp
kernels/dpnp_krnl_linalg.cpp
kernels/dpnp_krnl_manipulation.cpp
kernels/dpnp_krnl_mathematical.cpp
Expand Down
13 changes: 13 additions & 0 deletions dpnp/backend/include/dpnp_iface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,19 @@ template <typename _DataType, typename _ResultType>
INP_DLLEXPORT void dpnp_std_c(
void* array, void* result, const size_t* shape, size_t ndim, const size_t* axis, size_t naxis, size_t ddof);

/**
* @ingroup BACKEND_API
* @brief math library implementation of take function
*
* @param [in] array Input array with data.
* @param [in] array Input array with indices.
* @param [out] result Output array with indeces.
* @param [in] size Number of elements in the input array.
*/
template <typename _DataType>
INP_DLLEXPORT void dpnp_take_c(
void* array, void* indices, void* result, size_t size);

/**
* @ingroup BACKEND_API
* @brief math library implementation of var function
Expand Down
1 change: 1 addition & 0 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ enum class DPNPFuncName : size_t
DPNP_FN_SUBTRACT, /**< Used in numpy.subtract() implementation */
DPNP_FN_SUM, /**< Used in numpy.sum() implementation */
DPNP_FN_SVD, /**< Used in numpy.linalg.svd() implementation */
DPNP_FN_TAKE, /**< Used in numpy.take() implementation */
DPNP_FN_TAN, /**< Used in numpy.tan() implementation */
DPNP_FN_TANH, /**< Used in numpy.tanh() implementation */
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() implementation */
Expand Down
63 changes: 63 additions & 0 deletions dpnp/backend/kernels/dpnp_krnl_indexing.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//*****************************************************************************
// Copyright (c) 2016-2020, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <iostream>
#include <list>

#include <dpnp_iface.hpp>
#include "dpnp_fptr.hpp"
#include "dpnp_utils.hpp"
#include "queue_sycl.hpp"


template <typename _DataType>
class dpnp_take_c_kernel;

template <typename _DataType>
void dpnp_take_c(void* array1_in, void* indices1, void* result1, size_t size)
{
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
_DataType* result = reinterpret_cast<_DataType*>(result1);
size_t* indices = reinterpret_cast<size_t*>(indices1);

for (size_t i = 0; i < size; i++)
{
size_t ind = indices[i];
result[i] = array_1[ind];
}

return;
}


void func_map_init_indexing_func(func_map_t& fmap)
{
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_take_c<int>};
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_take_c<long>};
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_take_c<float>};
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_take_c<double>};

return;
}
1 change: 1 addition & 0 deletions dpnp/backend/src/dpnp_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ const DPNPFuncType eft_C128 = DPNPFuncType::DPNP_FT_CMPLX128;
void func_map_init_bitwise(func_map_t& fmap);
void func_map_init_elemwise(func_map_t& fmap);
void func_map_init_fft_func(func_map_t& fmap);
void func_map_init_indexing_func(func_map_t& fmap);
void func_map_init_linalg(func_map_t& fmap);
void func_map_init_linalg_func(func_map_t& fmap);
void func_map_init_manipulation(func_map_t& fmap);
Expand Down
1 change: 1 addition & 0 deletions dpnp/backend/src/dpnp_iface_fptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ static func_map_t func_map_init()
func_map_init_bitwise(fmap);
func_map_init_elemwise(fmap);
func_map_init_fft_func(fmap);
func_map_init_indexing_func(fmap);
func_map_init_linalg(fmap);
func_map_init_linalg_func(fmap);
func_map_init_manipulation(fmap);
Expand Down
1 change: 1 addition & 0 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_SUBTRACT
DPNP_FN_SUM
DPNP_FN_SVD
DPNP_FN_TAKE
DPNP_FN_TAN
DPNP_FN_TANH
DPNP_FN_TRANSPOSE
Expand Down
22 changes: 16 additions & 6 deletions dpnp/dpnp_algo/dpnp_algo_indexing.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ __all__ += [
]


ctypedef void(*custom_indexing_2in_1out_func_ptr_t)(void *, void * , void * , size_t)


cpdef dparray dpnp_choose(input, choices):
res_array = dparray(len(input), dtype=choices[0].dtype)
for i in range(len(input)):
Expand Down Expand Up @@ -259,12 +262,19 @@ cpdef dparray dpnp_select(condlist, choicelist, default):

cpdef dparray dpnp_take(dparray input, dparray indices):
indices_size = indices.size
res_array = dparray(indices_size, dtype=input.dtype)
for i in range(indices_size):
ind = indices[i]
res_array[i] = input[ind]
result = res_array.reshape(indices.shape)
return result

cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)

cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TAKE, param1_type, param1_type)

result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
cdef dparray result = dparray(indices_size, dtype=result_type)

cdef custom_indexing_2in_1out_func_ptr_t func = <custom_indexing_2in_1out_func_ptr_t > kernel_data.ptr

func(input.get_data(), indices.get_data(), result.get_data(), indices_size)

return result.reshape(indices.shape)


cpdef tuple dpnp_tril_indices(n, k=0, m=None):
Expand Down

0 comments on commit f367dfe

Please sign in to comment.