Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[metal] Revise NodeManager's implementation due to weak memory order #2008

Merged
merged 3 commits into from
Oct 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ class KernelManager::Impl {
if (compiled_structs_.need_snode_lists_data) {
auto *mem_alloc = reinterpret_cast<MemoryAllocator *>(addr);
// Make sure the retured memory address is always greater than 1.
mem_alloc->next = shaders::kAlignment;
mem_alloc->next = shaders::MemoryAllocator::kInitOffset;
// root list data are static
ListgenElement root_elem;
root_elem.mem_offset = 0;
Expand Down
80 changes: 34 additions & 46 deletions taichi/backends/metal/shaders/runtime_structs.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,18 @@ STR(
// clang-format on
constant constexpr int kTaichiMaxNumIndices = 8;
constant constexpr int kTaichiNumChunks = 1024;
constant constexpr int kAlignment = 8;
using PtrOffset = int32_t;

struct MemoryAllocator { atomic_int next; };
struct MemoryAllocator {
atomic_int next;

constant constexpr static int kInitOffset = 8;

static inline bool is_valid(PtrOffset v) {
return v >= kInitOffset;
}
};

// ListManagerData manages a list of elements with adjustable size.
struct ListManagerData {
Expand All @@ -44,6 +54,28 @@ STR(
atomic_int next;

atomic_int chunks[kTaichiNumChunks];

struct ReservedElemPtrOffset {
public:
ReservedElemPtrOffset() = default;
explicit ReservedElemPtrOffset(PtrOffset v) : val_(v) {
}

inline bool is_valid() const {
return is_valid(val_);
}

inline static bool is_valid(PtrOffset v) {
return MemoryAllocator::is_valid(v);
}

inline PtrOffset value() const {
return val_;
}

private:
PtrOffset val_{0};
};
};

// NodeManagerData stores the actual data needed to implement NodeManager
Expand All @@ -54,6 +86,7 @@ STR(
// few lists (ListManagerData). In particular, |data_list| stores the actual
// data, while |free_list| and |recycle_list| are only meant for GC.
struct NodeManagerData {
using ElemIndex = ListManagerData::ReservedElemPtrOffset;
// Stores the actual data.
ListManagerData data_list;
// For GC
Expand All @@ -62,51 +95,6 @@ STR(
atomic_int free_list_used;
// Need this field to bookkeep some data during GC
int recycled_list_size_backup;

// Use this type instead of the raw index type (int32_t), because the
// raw value needs to be shifted by |kIndexOffset| in order for the
// spinning memory allocation algorithm to work.
struct ElemIndex {
// The first 8 index values are reserved to encode special status:
// * 0 : nullptr
// * 1 : spinning for allocation
// * 2-7: unused for now
//
/// For each allocated index, it is added by |index_offset| to skip over
/// these reserved values.
constant static constexpr int32_t kIndexOffset = 8;

ElemIndex() = default;

static ElemIndex from_index(int i) {
return ElemIndex(i + kIndexOffset);
}

static ElemIndex from_raw(int r) {
return ElemIndex(r);
}

inline int32_t index() const {
return raw_ - kIndexOffset;
}

inline int32_t raw() const {
return raw_;
}

inline bool is_valid() const {
return raw_ >= kIndexOffset;
}

inline static bool is_valid(int raw) {
return ElemIndex::from_raw(raw).is_valid();
}

private:
explicit ElemIndex(int r) : raw_(r) {
}
int32_t raw_ = 0;
};
};

// This class is very similar to metal::SNodeDescriptor
Expand Down
82 changes: 46 additions & 36 deletions taichi/backends/metal/shaders/runtime_utils.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ struct Runtime {
// clang-format off
METAL_BEGIN_RUNTIME_UTILS_DEF
STR(
using PtrOffset = int32_t;
constant constexpr int kAlignment = 8;

[[maybe_unused]] PtrOffset mtl_memalloc_alloc(device MemoryAllocator *ma,
int32_t size) {
size = ((size + kAlignment - 1) / kAlignment) * kAlignment;
Expand All @@ -57,6 +54,7 @@ STR(
}

struct ListManager {
using ReservedElemPtrOffset = ListManagerData::ReservedElemPtrOffset;
device ListManagerData *lm_data;
device MemoryAllocator *mem_alloc;

Expand All @@ -74,22 +72,19 @@ STR(
resize(0);
}

struct ReserveElemResult {
int elem_idx;
PtrOffset chunk_ptr_offs;
};

ReserveElemResult reserve_new_elem() {
ReservedElemPtrOffset reserve_new_elem() {
const int elem_idx = atomic_fetch_add_explicit(
&lm_data->next, 1, metal::memory_order_relaxed);
const int chunk_idx = elem_idx >> lm_data->log2_num_elems_per_chunk;
const int chunk_idx = get_chunk_index(elem_idx);
const PtrOffset chunk_ptr_offs = ensure_chunk(chunk_idx);
return {elem_idx, chunk_ptr_offs};
const auto offset =
get_elem_ptr_offs_from_chunk(elem_idx, chunk_ptr_offs);
return ReservedElemPtrOffset{offset};
}

device char *append() {
auto reserved = reserve_new_elem();
return get_elem_from_chunk(reserved.elem_idx, reserved.chunk_ptr_offs);
return get_ptr(reserved);
}

template <typename T>
Expand All @@ -104,8 +99,12 @@ STR(
}
}

device char *get_ptr(ReservedElemPtrOffset offs) {
return mtl_memalloc_to_ptr(mem_alloc, offs.value());
}

device char *get_ptr(int i) {
const int chunk_idx = i >> lm_data->log2_num_elems_per_chunk;
const int chunk_idx = get_chunk_index(i);
const PtrOffset chunk_ptr_offs = atomic_load_explicit(
lm_data->chunks + chunk_idx, metal::memory_order_relaxed);
return get_elem_from_chunk(i, chunk_ptr_offs);
Expand All @@ -117,7 +116,11 @@ STR(
}

private:
PtrOffset ensure_chunk(int i) {
inline int get_chunk_index(int elem_idx) const {
return elem_idx >> lm_data->log2_num_elems_per_chunk;
}

PtrOffset ensure_chunk(int chunk_idx) {
PtrOffset offs = 0;
const int chunk_bytes =
(lm_data->element_stride << lm_data->log2_num_elems_per_chunk);
Expand All @@ -128,11 +131,11 @@ STR(
// from requesting memory again. Once allocated, set chunks[i] to the
// actual address offset, which is guaranteed to be greater than 1.
const bool is_me = atomic_compare_exchange_weak_explicit(
lm_data->chunks + i, &stored, 1, metal::memory_order_relaxed,
metal::memory_order_relaxed);
lm_data->chunks + chunk_idx, &stored, 1,
metal::memory_order_relaxed, metal::memory_order_relaxed);
if (is_me) {
offs = mtl_memalloc_alloc(mem_alloc, chunk_bytes);
atomic_store_explicit(lm_data->chunks + i, offs,
atomic_store_explicit(lm_data->chunks + chunk_idx, offs,
metal::memory_order_relaxed);
break;
} else if (stored > 1) {
Expand All @@ -144,11 +147,16 @@ STR(
return offs;
}

device char *get_elem_from_chunk(int i, PtrOffset chunk_ptr_offs) {
device char *chunk_ptr = reinterpret_cast<device char *>(
mtl_memalloc_to_ptr(mem_alloc, chunk_ptr_offs));
PtrOffset get_elem_ptr_offs_from_chunk(int elem_idx,
PtrOffset chunk_ptr_offs) {
const uint32_t mask = ((1 << lm_data->log2_num_elems_per_chunk) - 1);
return chunk_ptr + ((i & mask) * lm_data->element_stride);
return chunk_ptr_offs + ((elem_idx & mask) * lm_data->element_stride);
}

device char *get_elem_from_chunk(int elem_idx, PtrOffset chunk_ptr_offs) {
const auto offs =
get_elem_ptr_offs_from_chunk(elem_idx, chunk_ptr_offs);
return mtl_memalloc_to_ptr(mem_alloc, offs);
}
};

Expand All @@ -172,15 +180,15 @@ STR(
return free_list.get<ElemIndex>(cur_used);
}

return ElemIndex::from_index(data_list.reserve_new_elem().elem_idx);
return data_list.reserve_new_elem();
}

device byte *get(ElemIndex i) {
ListManager data_list;
data_list.lm_data = &(nm_data->data_list);
data_list.mem_alloc = mem_alloc;

return data_list.get_ptr(i.index());
return data_list.get_ptr(i);
}

void recycle(ElemIndex i) {
Expand Down Expand Up @@ -328,33 +336,35 @@ STR(

void activate(int i) {
device auto *nm_idx_ptr = to_nodemgr_idx_ptr(addr_, i);
auto nm_idx_raw =
auto nm_idx_val =
atomic_load_explicit(nm_idx_ptr, metal::memory_order_relaxed);
while (!ElemIndex::is_valid(nm_idx_raw)) {
nm_idx_raw = 0;
while (!ElemIndex::is_valid(nm_idx_val)) {
nm_idx_val = 0;
// See ListManager::ensure_chunk() for the allocation algorithm.
// See also https://github.com/taichi-dev/taichi/issues/1174.
const bool is_me = atomic_compare_exchange_weak_explicit(
nm_idx_ptr, &nm_idx_raw, 1, metal::memory_order_relaxed,
nm_idx_ptr, &nm_idx_val, 1, metal::memory_order_relaxed,
metal::memory_order_relaxed);
if (is_me) {
nm_idx_raw = nm_.allocate().raw();
atomic_store_explicit(nm_idx_ptr, nm_idx_raw,
nm_idx_val = nm_.allocate().value();
atomic_store_explicit(nm_idx_ptr, nm_idx_val,
metal::memory_order_relaxed);
break;
} else if (ElemIndex::is_valid(nm_idx_raw)) {
} else if (ElemIndex::is_valid(nm_idx_val)) {
break;
}
// |nm_idx_raw| == 1, just spin
// |nm_idx_val| == 1, just spin
}
}

void deactivate(int i) {
device auto *nm_idx_ptr = to_nodemgr_idx_ptr(addr_, i);
const auto old_nm_idx_raw = atomic_exchange_explicit(
const auto old_nm_idx_val = atomic_exchange_explicit(
nm_idx_ptr, 0, metal::memory_order_relaxed);
const auto old_nm_idx = ElemIndex::from_raw(old_nm_idx_raw);
if (!old_nm_idx.is_valid()) return;
const auto old_nm_idx = ElemIndex(old_nm_idx_val);
if (!old_nm_idx.is_valid()) {
return;
}
nm_.recycle(old_nm_idx);
}

Expand All @@ -366,8 +376,8 @@ STR(

static inline ElemIndex to_nodemgr_idx(device byte * addr, int ch_i) {
device auto *ptr = to_nodemgr_idx_ptr(addr, ch_i);
const auto r = atomic_load_explicit(ptr, metal::memory_order_relaxed);
return ElemIndex::from_raw(r);
const auto v = atomic_load_explicit(ptr, metal::memory_order_relaxed);
return ElemIndex(v);
}

static bool is_active(device byte * addr, int ch_i) {
Expand Down