Skip to content

Commit

Permalink
[SYCL] Stop reinterpret from changing buffer_allocator type to const (#…
Browse files Browse the repository at this point in the history
…6769)

This PR prevents the `buffer_allocator` from being rebound to a const
type. Currently when a `buffer` is reinterpreted from type `T` to type
`const T` the `buffer_allocator` type is also changed to `const T`. This
does not agree with the SYCL spec which states "A buffer of data type
const T uses buffer_allocator<T> by default."

This PR also adds a compile time test which validates the returned type.
  • Loading branch information
AidanBeltonS authored Oct 17, 2022
1 parent a3a88bc commit 3aabd26
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
12 changes: 6 additions & 6 deletions sycl/include/sycl/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ class buffer : public detail::buffer_plain {
template <typename ReinterpretT, int ReinterpretDim>
buffer<ReinterpretT, ReinterpretDim,
typename std::allocator_traits<AllocatorT>::template rebind_alloc<
ReinterpretT>>
std::remove_const_t<ReinterpretT>>>
reinterpret(range<ReinterpretDim> reinterpretRange) const {
if (sizeof(ReinterpretT) * reinterpretRange.size() != byte_size())
throw sycl::invalid_object_error(
Expand All @@ -637,8 +637,8 @@ class buffer : public detail::buffer_plain {
PI_ERROR_INVALID_VALUE);

return buffer<ReinterpretT, ReinterpretDim,
typename std::allocator_traits<
AllocatorT>::template rebind_alloc<ReinterpretT>>(
typename std::allocator_traits<AllocatorT>::
template rebind_alloc<std::remove_const_t<ReinterpretT>>>(
impl, reinterpretRange, OffsetInBytes, IsSubBuffer);
}

Expand All @@ -647,11 +647,11 @@ class buffer : public detail::buffer_plain {
(sizeof(ReinterpretT) == sizeof(T)) && (dimensions == ReinterpretDim),
buffer<ReinterpretT, ReinterpretDim,
typename std::allocator_traits<AllocatorT>::template rebind_alloc<
ReinterpretT>>>::type
std::remove_const_t<ReinterpretT>>>>::type
reinterpret() const {
return buffer<ReinterpretT, ReinterpretDim,
typename std::allocator_traits<
AllocatorT>::template rebind_alloc<ReinterpretT>>(
typename std::allocator_traits<AllocatorT>::
template rebind_alloc<std::remove_const_t<ReinterpretT>>>(
impl, get_range(), OffsetInBytes, IsSubBuffer);
}

Expand Down
70 changes: 70 additions & 0 deletions sycl/test/basic_tests/buffer/buffer_reinterpret.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RUN: %clangxx -fsycl -fsyntax-only %s

#include <sycl/sycl.hpp>

template <int Dimensions> sycl::range<Dimensions> create_range() {
return sycl::range<Dimensions>(1);
}

template <> sycl::range<2> create_range() { return sycl::range<2>(1, 1); }

template <> sycl::range<3> create_range() { return sycl::range<3>(1, 1, 1); }

// Compile only test to check that buffer_allocator type does not get
// reinterpreted with const keyworkd
template <typename T, int Dimensions,
typename Allocator = sycl::buffer_allocator<T>>
void test_buffer_const_reinterpret() {
sycl::buffer<T, Dimensions, Allocator> buff(create_range<Dimensions>());
sycl::buffer<T const, Dimensions, Allocator> const_buff(
create_range<Dimensions>());

auto reinterpret_buff = buff.template reinterpret<T const, Dimensions>(
create_range<Dimensions>());

static_assert(
std::is_same_v<decltype(const_buff), decltype(reinterpret_buff)>);
}

struct my_struct {
int my_int = 0;
float my_float = 0;
double my_double = 0;
};

int main() {
test_buffer_const_reinterpret<short, 1>();
test_buffer_const_reinterpret<int, 1>();
test_buffer_const_reinterpret<long, 1>();
test_buffer_const_reinterpret<unsigned short, 1>();
test_buffer_const_reinterpret<unsigned int, 1>();
test_buffer_const_reinterpret<unsigned long, 1>();
test_buffer_const_reinterpret<sycl::half, 1>();
test_buffer_const_reinterpret<float, 1>();
test_buffer_const_reinterpret<double, 1>();
test_buffer_const_reinterpret<my_struct, 1>();

test_buffer_const_reinterpret<short, 2>();
test_buffer_const_reinterpret<int, 2>();
test_buffer_const_reinterpret<long, 2>();
test_buffer_const_reinterpret<unsigned short, 2>();
test_buffer_const_reinterpret<unsigned int, 2>();
test_buffer_const_reinterpret<unsigned long, 2>();
test_buffer_const_reinterpret<sycl::half, 2>();
test_buffer_const_reinterpret<float, 2>();
test_buffer_const_reinterpret<double, 2>();
test_buffer_const_reinterpret<my_struct, 2>();

test_buffer_const_reinterpret<short, 3>();
test_buffer_const_reinterpret<int, 3>();
test_buffer_const_reinterpret<long, 3>();
test_buffer_const_reinterpret<unsigned short, 3>();
test_buffer_const_reinterpret<unsigned int, 3>();
test_buffer_const_reinterpret<unsigned long, 3>();
test_buffer_const_reinterpret<sycl::half, 3>();
test_buffer_const_reinterpret<float, 3>();
test_buffer_const_reinterpret<double, 3>();
test_buffer_const_reinterpret<my_struct, 3>();

return 0;
}

0 comments on commit 3aabd26

Please sign in to comment.