Skip to content

Commit

Permalink
more fixes and Anonymous space/const fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nmm0 committed Jul 10, 2024
1 parent afadc55 commit 58ea2f3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
7 changes: 6 additions & 1 deletion core/src/Kokkos_View.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class View : public Impl::BasicViewFromTraits<DataType, Properties...>::type {
using device_type = typename traits::device_type;
using pointer_type = typename traits::value_type*;
using memory_traits = typename traits::memory_traits;
using host_mirror_space = typename traits::host_mirror_space;

// typedefs from BasicView
using mdspan_type = typename base_t::mdspan_type;
Expand Down Expand Up @@ -611,7 +612,11 @@ class View : public Impl::BasicViewFromTraits<DataType, Properties...>::type {
// Memory span required to wrap these dimensions.
static constexpr size_t required_allocation_size(
typename traits::array_layout const& layout) {
return Impl::mapping_from_array_layout<typename base_t::mapping_type>(layout).required_span_size()*sizeof(value_type);
return Impl::mapping_from_array_layout<
typename traits::array_layout, typename base_t::mapping_type>(
layout)
.required_span_size() *
sizeof(value_type);
}

static constexpr size_t required_allocation_size(
Expand Down
20 changes: 16 additions & 4 deletions core/src/View/MDSpan/Kokkos_MDSpan_Accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,13 @@ class ReferenceCountedDataHandle {
ReferenceCountedDataHandle(OtherElementType* ptr)
: m_tracker(), m_handle(ptr) {}

template <class OtherElementType,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], value_type (*)[]>>>
ReferenceCountedDataHandle(
const ReferenceCountedDataHandle<OtherElementType, memory_space>& other)
: m_tracker(other.m_tracker), m_handle(other.m_handle) {}

ReferenceCountedDataHandle(const ReferenceCountedDataHandle&) = default;
ReferenceCountedDataHandle(ReferenceCountedDataHandle&&) noexcept = default;
ReferenceCountedDataHandle& operator=(const ReferenceCountedDataHandle&) =
Expand All @@ -256,8 +263,8 @@ class ReferenceCountedDataHandle {
std::string get_label() const { return m_tracker.get_label<memory_space>(); }

private:

friend class ReferenceCountedDataHandle<ElementType, AnonymousSpace>;
template <class OtherElementType, class OtherSpace>
friend class ReferenceCountedDataHandle;
SharedAllocationTracker m_tracker;
pointer m_handle = nullptr;
};
Expand Down Expand Up @@ -288,9 +295,11 @@ class ReferenceCountedDataHandle<ElementType, AnonymousSpace> {
default;
ReferenceCountedDataHandle& operator=(ReferenceCountedDataHandle&&) = default;

template <class OtherSpace>
template <class OtherElementType, class OtherSpace,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], value_type (*)[]>>>
ReferenceCountedDataHandle(
const ReferenceCountedDataHandle<ElementType, OtherSpace>& other)
const ReferenceCountedDataHandle<OtherElementType, OtherSpace>& other)
: m_tracker(other.m_tracker), m_handle(other.m_handle) {}

ReferenceCountedDataHandle with_offset(size_t offset) const {
Expand All @@ -309,6 +318,9 @@ class ReferenceCountedDataHandle<ElementType, AnonymousSpace> {
std::string get_label() const { return m_tracker.get_label<memory_space>(); }

private:
template <class OtherElementType, class OtherSpace>
friend class ReferenceCountedDataHandle;

SharedAllocationTracker m_tracker;
pointer m_handle = nullptr;
};
Expand Down

0 comments on commit 58ea2f3

Please sign in to comment.