Skip to content

Commit

Permalink
Move to threading utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 17, 2022
1 parent e5a4473 commit 0c25c93
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 33 deletions.
33 changes: 0 additions & 33 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -642,39 +642,6 @@ class GHistBuilder {
/*! \brief number of all bins over all features */
uint32_t nbins_ { 0 };
};

/*!
* \brief A C-style array with in-stack allocation. As long as the array is smaller than
* MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
* heap-allocated.
*/
template <typename T, size_t MaxStackSize>
class MemStackAllocator {
public:
explicit MemStackAllocator(size_t required_size) : required_size_(required_size) {
if (MaxStackSize >= required_size_) {
ptr_ = stack_mem_;
} else {
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
}
if (!ptr_) {
throw std::bad_alloc{};
}
}

~MemStackAllocator() {
if (required_size_ > MaxStackSize) {
free(ptr_);
}
}
T& operator[](size_t i) { return ptr_[i]; }
T const& operator[](size_t i) const { return ptr_[i]; }

private:
T* ptr_ = nullptr;
size_t required_size_;
T stack_mem_[MaxStackSize];
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_HIST_UTIL_H_
34 changes: 34 additions & 0 deletions src/common/threading_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,40 @@ inline int32_t OmpGetNumThreads(int32_t n_threads) {
n_threads = std::max(n_threads, 1);
return n_threads;
}


/*!
* \brief A C-style array with in-stack allocation. As long as the array is smaller than
* MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
* heap-allocated.
*/
template <typename T, size_t MaxStackSize>
class MemStackAllocator {
public:
explicit MemStackAllocator(size_t required_size) : required_size_(required_size) {
if (MaxStackSize >= required_size_) {
ptr_ = stack_mem_;
} else {
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
}
if (!ptr_) {
throw std::bad_alloc{};
}
}

~MemStackAllocator() {
if (required_size_ > MaxStackSize) {
free(ptr_);
}
}
T& operator[](size_t i) { return ptr_[i]; }
T const& operator[](size_t i) const { return ptr_[i]; }

private:
T* ptr_ = nullptr;
size_t required_size_;
T stack_mem_[MaxStackSize];
};
} // namespace common
} // namespace xgboost

Expand Down
1 change: 1 addition & 0 deletions src/data/gradient_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../common/threading_utils.h"

namespace xgboost {

Expand Down

0 comments on commit 0c25c93

Please sign in to comment.