Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dpnp fft implementations to run on Iris Xe #1524

Merged
merged 12 commits into from
Aug 21, 2023
44 changes: 30 additions & 14 deletions dpnp/backend/kernels/dpnp_krnl_fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,12 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
std::is_same<_DataType_input, int64_t>::value)
{
double *array1_copy = reinterpret_cast<double *>(
dpnp_memory_alloc_c(q_ref, input_size * sizeof(double)));
using CastType = std::conditional_t<
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
std::is_same<_DataType_output, std::complex<double>>::value,
double, float>;
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved

CastType *array1_copy = reinterpret_cast<CastType *>(
dpnp_memory_alloc_c(q_ref, input_size * sizeof(CastType)));

shape_elem_type *copy_strides = reinterpret_cast<shape_elem_type *>(
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
Expand All @@ -486,15 +490,17 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
*copy_shape = input_size;
shape_elem_type copy_shape_size = 1;
event_ref = dpnp_copyto_c<_DataType_input, double>(
event_ref = dpnp_copyto_c<_DataType_input, CastType>(
q_ref, array1_copy, input_size, copy_shape_size, copy_shape,
copy_strides, array1_in, input_size, copy_shape_size,
copy_shape, copy_strides, NULL, dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);

event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double,
desc_dp_real_t>(
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<
CastType, CastType,
std::conditional_t<std::is_same<CastType, double>::value,
desc_dp_real_t, desc_sp_real_t>>(
q_ref, array1_copy, result_out, input_shape, result_shape,
shape_size, input_size, result_size, inverse, norm, 0);

Expand Down Expand Up @@ -617,8 +623,12 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
std::is_same<_DataType_input, int64_t>::value)
{
double *array1_copy = reinterpret_cast<double *>(
dpnp_memory_alloc_c(q_ref, input_size * sizeof(double)));
using CastType = std::conditional_t<
std::is_same<_DataType_output, std::complex<double>>::value,
double, float>;

CastType *array1_copy = reinterpret_cast<CastType *>(
dpnp_memory_alloc_c(q_ref, input_size * sizeof(CastType)));

shape_elem_type *copy_strides = reinterpret_cast<shape_elem_type *>(
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
Expand All @@ -627,15 +637,17 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
*copy_shape = input_size;
shape_elem_type copy_shape_size = 1;
event_ref = dpnp_copyto_c<_DataType_input, double>(
event_ref = dpnp_copyto_c<_DataType_input, CastType>(
q_ref, array1_copy, input_size, copy_shape_size, copy_shape,
copy_strides, array1_in, input_size, copy_shape_size,
copy_shape, copy_strides, NULL, dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);

event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double,
desc_dp_real_t>(
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<
CastType, CastType,
std::conditional_t<std::is_same<CastType, double>::value,
desc_dp_real_t, desc_sp_real_t>>(
q_ref, array1_copy, result_out, input_shape, result_shape,
shape_size, input_size, result_size, inverse, norm, 1);

Expand Down Expand Up @@ -721,9 +733,11 @@ void func_map_init_fft_func(func_map_t &fmap)
dpnp_fft_fft_default_c<std::complex<double>, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_INT][eft_INT] = {
eft_C128, (void *)dpnp_fft_fft_ext_c<int32_t, std::complex<double>>};
eft_C128, (void *)dpnp_fft_fft_ext_c<int32_t, std::complex<double>>,
eft_C64, (void *)dpnp_fft_fft_ext_c<int32_t, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_LNG][eft_LNG] = {
eft_C128, (void *)dpnp_fft_fft_ext_c<int64_t, std::complex<double>>};
eft_C128, (void *)dpnp_fft_fft_ext_c<int64_t, std::complex<double>>,
eft_C64, (void *)dpnp_fft_fft_ext_c<int64_t, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_FLT][eft_FLT] = {
eft_C64, (void *)dpnp_fft_fft_ext_c<float, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_DBL][eft_DBL] = {
Expand All @@ -748,9 +762,11 @@ void func_map_init_fft_func(func_map_t &fmap)
(void *)dpnp_fft_rfft_default_c<double, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_INT][eft_INT] = {
eft_C128, (void *)dpnp_fft_rfft_ext_c<int32_t, std::complex<double>>};
eft_C128, (void *)dpnp_fft_rfft_ext_c<int32_t, std::complex<double>>,
eft_C64, (void *)dpnp_fft_rfft_ext_c<int32_t, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_LNG][eft_LNG] = {
eft_C128, (void *)dpnp_fft_rfft_ext_c<int64_t, std::complex<double>>};
eft_C128, (void *)dpnp_fft_rfft_ext_c<int64_t, std::complex<double>>,
eft_C64, (void *)dpnp_fft_rfft_ext_c<int64_t, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_FLT][eft_FLT] = {
eft_C64, (void *)dpnp_fft_rfft_ext_c<float, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_DBL][eft_DBL] = {
Expand Down
18 changes: 14 additions & 4 deletions dpnp/fft/dpnp_algo_fft.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,15 @@ cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,

input_obj = input.get_array()

# get FPTR function and return type
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
input_obj.sycl_device.has_aspect_fp64)
cdef DPNPFuncType return_type = ret_type_and_func[0]
cdef fptr_dpnp_fft_fft_t func = < fptr_dpnp_fft_fft_t > ret_type_and_func[1]

# create result array with type given by FPTR data
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(output_shape,
kernel_data.return_type,
return_type,
None,
device=input_obj.sycl_device,
usm_type=input_obj.usm_type,
Expand All @@ -81,7 +87,6 @@ cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
# call FPTR function
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
input.get_data(),
Expand Down Expand Up @@ -122,9 +127,15 @@ cpdef utils.dpnp_descriptor dpnp_rfft(utils.dpnp_descriptor input,

input_obj = input.get_array()

# get FPTR function and return type
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
input_obj.sycl_device.has_aspect_fp64)
cdef DPNPFuncType return_type = ret_type_and_func[0]
cdef fptr_dpnp_fft_fft_t func = < fptr_dpnp_fft_fft_t > ret_type_and_func[1]

# create result array with type given by FPTR data
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(output_shape,
kernel_data.return_type,
return_type,
None,
device=input_obj.sycl_device,
usm_type=input_obj.usm_type,
Expand All @@ -135,7 +146,6 @@ cpdef utils.dpnp_descriptor dpnp_rfft(utils.dpnp_descriptor input,
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
# call FPTR function
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
input.get_data(),
Expand Down
Loading