diff --git a/help_function/src/onedpl_test_device_ptr.cpp b/help_function/src/onedpl_test_device_ptr.cpp index 2ec4aad27..33652cf42 100644 --- a/help_function/src/onedpl_test_device_ptr.cpp +++ b/help_function/src/onedpl_test_device_ptr.cpp @@ -19,16 +19,38 @@ #include -template -int ASSERT_EQUAL(String msg, _T1&& X, _T2&& Y) { - if(X!=Y) { +template +int +ASSERT_EQUAL(String msg, _T1&& X, _T2&& Y, bool skip_pass_msg = false) +{ + if (X != Y) + { std::cout << "FAIL: " << msg << " - (" << X << "," << Y << ")" << std::endl; return 1; } - else { + else if (!skip_pass_msg) + { std::cout << "PASS: " << msg << std::endl; - return 0; } + return 0; +} + +template +int +ASSERT_EQUAL_N(String msg, _T1&& X, _T2&& Y, ::std::size_t n) +{ + int failed_tests = 0; + for (size_t i = 0; i < n; i++) + { + failed_tests += ASSERT_EQUAL(msg, *X, *Y, true); + X++; + Y++; + } + if (failed_tests == 0) + { + std::cout << "PASS: " << msg << std::endl; + } + return failed_tests; } int test_device_ptr_manipulation(void) @@ -115,12 +137,48 @@ int test_device_ptr_iteration(void) return failing_tests; } -int main() { - int failed_tests = test_device_ptr_manipulation(); +int +test_permutation_iterator() +{ + int failing_tests = 0; + typedef size_t T; +#ifdef DPCT_USM_LEVEL_NONE + sycl::buffer data(sycl::range<1>(1024)); + + dpct::device_pointer begin(data, 0); + dpct::device_pointer end(data, 1024); + + sycl::buffer data_res(sycl::range<1>(1024)); + dpct::device_pointer begin_res(data_res, 0); + dpct::device_pointer end_res(data_res, 1024); +#else + dpct::device_pointer data(1024 * sizeof(T)); + dpct::device_pointer begin(data); + dpct::device_pointer end(data + 1024); + + dpct::device_pointer data_res(1024 * sizeof(T)); + dpct::device_pointer begin_res(data_res); + dpct::device_pointer end_res(data_res + 1024); +#endif + auto policy = oneapi::dpl::execution::make_device_policy(dpct::get_default_queue()); + std::fill(begin, end, T(1)); + std::fill(begin_res, end_res, T(99)); + auto perm = oneapi::dpl::make_permutation_iterator(begin, oneapi::dpl::counting_iterator(0)); + std::copy(policy, perm, perm + 1024, begin_res); + return ASSERT_EQUAL_N("device_ptr in permutation_iterator", begin_res, dpct::make_constant_iterator(T(1)), 1024); +} + +int +main() +{ + int failed_tests = 0; + failed_tests += test_device_ptr_manipulation(); failed_tests += test_device_ptr_iteration(); + failed_tests += test_permutation_iterator(); std::cout << std::endl << failed_tests << " failing test(s) detected." << std::endl; - if (failed_tests == 0) { + if (failed_tests == 0) + { return 0; } return 1;