From 3aabd2640386849134d54f8021d1b74dfb5932ea Mon Sep 17 00:00:00 2001 From: AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> Date: Mon, 17 Oct 2022 12:04:24 +0100 Subject: [PATCH] [SYCL] Stop reinterpret from changing buffer_allocator type to const (#6769) This PR prevents the `buffer_allocator` from being rebound to a const type. Currently when a `buffer` is reinterpreted from type `T` to type `const T` the `buffer_allocator` type is also changed to `const T`. This does not agree with the SYCL spec which states "A buffer of data type const T uses buffer_allocator by default." This PR also adds a compile time test which validates the returned type. --- sycl/include/sycl/buffer.hpp | 12 ++-- .../basic_tests/buffer/buffer_reinterpret.cpp | 70 +++++++++++++++++++ 2 files changed, 76 insertions(+), 6 deletions(-) create mode 100644 sycl/test/basic_tests/buffer/buffer_reinterpret.cpp diff --git a/sycl/include/sycl/buffer.hpp b/sycl/include/sycl/buffer.hpp index 3225f7827737c..ffcfcc9802677 100644 --- a/sycl/include/sycl/buffer.hpp +++ b/sycl/include/sycl/buffer.hpp @@ -627,7 +627,7 @@ class buffer : public detail::buffer_plain { template buffer::template rebind_alloc< - ReinterpretT>> + std::remove_const_t>> reinterpret(range reinterpretRange) const { if (sizeof(ReinterpretT) * reinterpretRange.size() != byte_size()) throw sycl::invalid_object_error( @@ -637,8 +637,8 @@ class buffer : public detail::buffer_plain { PI_ERROR_INVALID_VALUE); return buffer::template rebind_alloc>( + typename std::allocator_traits:: + template rebind_alloc>>( impl, reinterpretRange, OffsetInBytes, IsSubBuffer); } @@ -647,11 +647,11 @@ class buffer : public detail::buffer_plain { (sizeof(ReinterpretT) == sizeof(T)) && (dimensions == ReinterpretDim), buffer::template rebind_alloc< - ReinterpretT>>>::type + std::remove_const_t>>>::type reinterpret() const { return buffer::template rebind_alloc>( + typename std::allocator_traits:: + template rebind_alloc>>( impl, get_range(), OffsetInBytes, IsSubBuffer); } diff --git a/sycl/test/basic_tests/buffer/buffer_reinterpret.cpp b/sycl/test/basic_tests/buffer/buffer_reinterpret.cpp new file mode 100644 index 0000000000000..c6a12538bbb05 --- /dev/null +++ b/sycl/test/basic_tests/buffer/buffer_reinterpret.cpp @@ -0,0 +1,70 @@ +// RUN: %clangxx -fsycl -fsyntax-only %s + +#include + +template sycl::range create_range() { + return sycl::range(1); +} + +template <> sycl::range<2> create_range() { return sycl::range<2>(1, 1); } + +template <> sycl::range<3> create_range() { return sycl::range<3>(1, 1, 1); } + +// Compile only test to check that buffer_allocator type does not get +// reinterpreted with const keyworkd +template > +void test_buffer_const_reinterpret() { + sycl::buffer buff(create_range()); + sycl::buffer const_buff( + create_range()); + + auto reinterpret_buff = buff.template reinterpret( + create_range()); + + static_assert( + std::is_same_v); +} + +struct my_struct { + int my_int = 0; + float my_float = 0; + double my_double = 0; +}; + +int main() { + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + test_buffer_const_reinterpret(); + + return 0; +}