diff --git a/src/common/common.h b/src/common/common.h
index aaf2b8a37dc0..0ead15c3fbec 100644
--- a/src/common/common.h
+++ b/src/common/common.h
@@ -6,6 +6,8 @@
 #ifndef XGBOOST_COMMON_COMMON_H_
 #define XGBOOST_COMMON_COMMON_H_
 
+#include <xgboost/base.h>
+
 #include
 #include
 #include
@@ -35,6 +37,71 @@ inline std::string ToString(const T& data) {
   return os.str();
 }
 
+/*
+ * Range iterator
+ */
+class Range {
+ public:
+  class Iterator {
+    friend class Range;
+
+   public:
+    using DifferenceType = int64_t;
+
+    XGBOOST_DEVICE int64_t operator*() const { return i_; }
+    XGBOOST_DEVICE const Iterator &operator++() {
+      i_ += step_;
+      return *this;
+    }
+    XGBOOST_DEVICE Iterator operator++(int) {
+      Iterator res {*this};
+      i_ += step_;
+      return res;
+    }
+
+    XGBOOST_DEVICE bool operator==(const Iterator &other) const {
+      return i_ >= other.i_;
+    }
+    XGBOOST_DEVICE bool operator!=(const Iterator &other) const {
+      return i_ < other.i_;
+    }
+
+    XGBOOST_DEVICE void Step(DifferenceType s) { step_ = s; }
+
+   protected:
+    XGBOOST_DEVICE explicit Iterator(int64_t start) : i_(start) {}
+    XGBOOST_DEVICE explicit Iterator(int64_t start, int step) :
+        i_{start}, step_{step} {}
+
+   public:
+    int64_t i_;
+    DifferenceType step_ = 1;
+  };
+
+  XGBOOST_DEVICE Iterator begin() const { return begin_; }  // NOLINT
+  XGBOOST_DEVICE Iterator end() const { return end_; }  // NOLINT
+
+  XGBOOST_DEVICE Range(int64_t begin, int64_t end)
+      : begin_(begin), end_(end) {}
+  XGBOOST_DEVICE Range(int64_t begin, int64_t end, Iterator::DifferenceType step)
+      : begin_(begin, step), end_(end) {}
+
+  XGBOOST_DEVICE bool operator==(const Range& other) const {
+    return *begin_ == *other.begin_ && *end_ == *other.end_;
+  }
+  XGBOOST_DEVICE bool operator!=(const Range& other) const {
+    return !(*this == other);
+  }
+
+  XGBOOST_DEVICE void Step(Iterator::DifferenceType s) { begin_.Step(s); }
+
+  XGBOOST_DEVICE Iterator::DifferenceType GetStep() const { return begin_.step_; }
+
+ private:
+  Iterator begin_;
+  Iterator end_;
+};
+
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_COMMON_H_
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 6358a005782e..844edad59e0e 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -7,6 +7,10 @@
 #include
 #include
 #include
+
+#include "common.h"
+#include "gpu_set.h"
+
 #include
 #include
 #include
@@ -28,25 +32,6 @@ namespace dh {
 #define HOST_DEV_INLINE XGBOOST_DEVICE __forceinline__
 #define DEV_INLINE __device__ __forceinline__
 
-/*
- * Error handling functions
- */
-
-#define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__)
-
-inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
-                                    int line) {
-  if (code != cudaSuccess) {
-    std::stringstream ss;
-    ss << file << "(" << line << ")";
-    std::string file_and_line;
-    ss >> file_and_line;
-    throw thrust::system_error(code, thrust::cuda_category(), file_and_line);
-  }
-
-  return code;
-}
-
 #ifdef XGBOOST_USE_NCCL
 #define safe_nccl(ans) ThrowOnNcclError((ans), __FILE__, __LINE__)
 
@@ -73,47 +58,22 @@ const T *Raw(const thrust::device_vector<T> &v) {  // NOLINT
   return raw_pointer_cast(v.data());
 }
 
-inline int NVisibleDevices() {
-  int n_visgpus = 0;
-
-  dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
-
-  return n_visgpus;
-}
-
-inline int NDevicesAll(int n_gpus) {
-  int n_devices_visible = dh::NVisibleDevices();
-  int n_devices = n_gpus < 0 ? n_devices_visible : n_gpus;
-  return (n_devices);
-}
-inline int NDevices(int n_gpus, int num_rows) {
-  int n_devices = dh::NDevicesAll(n_gpus);
-  // fix-up device number to be limited by number of rows
-  n_devices = n_devices > num_rows ? num_rows : n_devices;
-  return (n_devices);
-}
-
 // if n_devices=-1, then use all visible devices
-inline void SynchronizeNDevices(int n_devices, std::vector<int> dList) {
-  for (int d_idx = 0; d_idx < n_devices; d_idx++) {
-    int device_idx = dList[d_idx];
-    safe_cuda(cudaSetDevice(device_idx));
+inline void SynchronizeNDevices(xgboost::GPUSet devices) {
+  devices = devices.IsEmpty() ? xgboost::GPUSet::AllVisible() : devices;
+  for (auto const d : devices.Unnormalised()) {
+    safe_cuda(cudaSetDevice(d));
     safe_cuda(cudaDeviceSynchronize());
   }
 }
+
 inline void SynchronizeAll() {
-  for (int device_idx = 0; device_idx < NVisibleDevices(); device_idx++) {
+  for (int device_idx : xgboost::GPUSet::AllVisible()) {
     safe_cuda(cudaSetDevice(device_idx));
     safe_cuda(cudaDeviceSynchronize());
   }
 }
 
-inline std::string DeviceName(int device_idx) {
-  cudaDeviceProp prop;
-  dh::safe_cuda(cudaGetDeviceProperties(&prop, device_idx));
-  return std::string(prop.name);
-}
-
 inline size_t AvailableMemory(int device_idx) {
   size_t device_free = 0;
   size_t device_total = 0;
@@ -144,15 +104,8 @@ inline size_t MaxSharedMemory(int device_idx) {
   return prop.sharedMemPerBlock;
 }
 
-// ensure gpu_id is correct, so not dependent upon user knowing details
-inline int GetDeviceIdx(int gpu_id) {
-  // protect against overrun for gpu_id
-  return (std::abs(gpu_id) + 0) % dh::NVisibleDevices();
-}
-
 inline void CheckComputeCapability() {
-  int n_devices = NVisibleDevices();
-  for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
+  for (int d_idx : xgboost::GPUSet::AllVisible()) {
     cudaDeviceProp prop;
     safe_cuda(cudaGetDeviceProperties(&prop, d_idx));
     std::ostringstream oss;
@@ -163,12 +116,11 @@ inline void CheckComputeCapability() {
   }
 }
 
-
 DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, unsigned char b) {
   atomicOr(&buffer[ibyte / sizeof(unsigned int)], (unsigned int)b << (ibyte % (sizeof(unsigned int)) * 8));
 }
 
-/*!
+/*!
  * \brief Find the strict upper bound for an element in a sorted array
  * using binary search.
  * \param cuts pointer to the first element of the sorted array
@@ -199,67 +151,18 @@ DEV_INLINE int UpperBound(const float* __restrict__ cuts, int n, float v) {
   return right;
 }
 
-/*
- * Range iterator
- */
-
-class Range {
- public:
-  class Iterator {
-    friend class Range;
-
-   public:
-    XGBOOST_DEVICE int64_t operator*() const { return i_; }
-    XGBOOST_DEVICE const Iterator &operator++() {
-      i_ += step_;
-      return *this;
-    }
-    XGBOOST_DEVICE Iterator operator++(int) {
-      Iterator copy(*this);
-      i_ += step_;
-      return copy;
-    }
-
-    XGBOOST_DEVICE bool operator==(const Iterator &other) const {
-      return i_ >= other.i_;
-    }
-    XGBOOST_DEVICE bool operator!=(const Iterator &other) const {
-      return i_ < other.i_;
-    }
-
-    XGBOOST_DEVICE void Step(int s) { step_ = s; }
-
-   protected:
-    XGBOOST_DEVICE explicit Iterator(int64_t start) : i_(start) {}
-
-   public:
-    uint64_t i_;
-    int step_ = 1;
-  };
-
-  XGBOOST_DEVICE Iterator begin() const { return begin_; }  // NOLINT
-  XGBOOST_DEVICE Iterator end() const { return end_; }  // NOLINT
-  XGBOOST_DEVICE Range(int64_t begin, int64_t end)
-      : begin_(begin), end_(end) {}
-  XGBOOST_DEVICE void Step(int s) { begin_.Step(s); }
-
- private:
-  Iterator begin_;
-  Iterator end_;
-};
-
 template <typename T>
-__device__ Range GridStrideRange(T begin, T end) {
+__device__ xgboost::common::Range GridStrideRange(T begin, T end) {
   begin += blockDim.x * blockIdx.x + threadIdx.x;
-  Range r(begin, end);
+  xgboost::common::Range r(begin, end);
   r.Step(gridDim.x * blockDim.x);
   return r;
 }
 
 template <typename T>
-__device__ Range BlockStrideRange(T begin, T end) {
+__device__ xgboost::common::Range BlockStrideRange(T begin, T end) {
   begin += threadIdx.x;
-  Range r(begin, end);
+  xgboost::common::Range r(begin, end);
   r.Step(blockDim.x);
   return r;
 }
@@ -557,7 +460,7 @@ class BulkAllocator {
   BulkAllocator(BulkAllocator&&) = delete;
   void operator=(const BulkAllocator&) = delete;
   void operator=(BulkAllocator&&) = delete;
-  
+
   ~BulkAllocator() {
     for (size_t i = 0; i < d_ptr_.size(); i++) {
       if (!(d_ptr_[i] == nullptr)) {
diff --git a/src/common/gpu_set.h b/src/common/gpu_set.h
new file mode 100644
index 000000000000..ed9b595a19bd
--- /dev/null
+++ b/src/common/gpu_set.h
@@ -0,0 +1,122 @@
+/*!
+ * Copyright 2018 XGBoost contributors
+ */
+#ifndef XGBOOST_COMMON_GPU_SET_H_
+#define XGBOOST_COMMON_GPU_SET_H_
+
+#include <dmlc/logging.h>
+#include <xgboost/base.h>
+
+#include <cstdlib>
+#include <limits>
+
+#include "common.h"
+#include "span.h"
+
+#if defined(__CUDACC__)
+#include <thrust/system/cuda/error.h>
+#include <thrust/system_error.h>
+#endif
+
+namespace dh {
+#if defined(__CUDACC__)
+/*
+ * Error handling functions
+ */
+#define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__)
+
+inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
+                                    int line) {
+  if (code != cudaSuccess) {
+    throw thrust::system_error(code, thrust::cuda_category(),
+                               std::string{file} + "(" +  // NOLINT
+                               std::to_string(line) + ")");
+  }
+  return code;
+}
+#endif
+}  // namespace dh
+
+namespace xgboost {
+
+/* \brief set of devices across which HostDeviceVector can be distributed.
+ *
+ * Currently implemented as a range, but can be changed later to something else,
+ * e.g. a bitset
+ */
+class GPUSet {
+ public:
+  explicit GPUSet(int start = 0, int ndevices = 0)
+      : devices_(start, start + ndevices) {}
+
+  static GPUSet Empty() { return GPUSet(); }
+
+  static GPUSet Range(int start, int ndevices) {
+    return ndevices <= 0 ? Empty() : GPUSet{start, ndevices};
+  }
+  /* \brief ndevices and num_rows both are upper bounds. */
+  static GPUSet All(int ndevices, int num_rows = std::numeric_limits<int>::max()) {
+    int n_devices_visible = AllVisible().Size();
+    ndevices = ndevices < 0 ? n_devices_visible : ndevices;
+    // fix-up device number to be limited by number of rows
+    ndevices = ndevices > num_rows ? num_rows : ndevices;
+    return Range(0, ndevices);
+  }
+
+  static GPUSet AllVisible() {
+    int n_visgpus = 0;
+#if defined(__CUDACC__)
+    dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
+#endif
+    return Range(0, n_visgpus);
+  }
+  /* \brief Ensure gpu_id is correct, so not dependent upon user knowing details */
+  static int GetDeviceIdx(int gpu_id) {
+    return (std::abs(gpu_id) + 0) % AllVisible().Size();
+  }
+  /* \brief Counting from gpu_id */
+  GPUSet Normalised(int gpu_id) const {
+    return Range(gpu_id, *devices_.end() + gpu_id);
+  }
+  /* \brief Counting from 0 */
+  GPUSet Unnormalised() const {
+    return Range(0, *devices_.end() - *devices_.begin());
+  }
+
+  int Size() const {
+    int res = *devices_.end() - *devices_.begin();
+    return res < 0 ? 0 : res;
+  }
+
+  int operator[](int index) const {
+    CHECK(index >= 0 && index < *(devices_.end()));
+    return *devices_.begin() + index;
+  }
+
+  bool IsEmpty() const { return Size() == 0; }  // NOLINT
+
+  int Index(int device) const {
+    CHECK(Contains(device));
+    return device - *devices_.begin();
+  }
+
+  bool Contains(int device) const {
+    return *devices_.begin() <= device && device < *devices_.end();
+  }
+
+  common::Range::Iterator begin() const { return devices_.begin(); }  // NOLINT
+  common::Range::Iterator end() const { return devices_.end(); }  // NOLINT
+
+  friend bool operator==(const GPUSet& lhs, const GPUSet& rhs) {
+    return lhs.devices_ == rhs.devices_;
+  }
+  friend bool operator!=(const GPUSet& lhs, const GPUSet& rhs) {
+    return !(lhs == rhs);
+  }
+
+ private:
+  common::Range devices_;
+};
+}  // namespace xgboost
+
+#endif  // XGBOOST_COMMON_GPU_SET_H_
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index bb4453f57b88..894d877512da 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -271,7 +271,7 @@ struct GPUSketcher {
     find_cuts_k<<>>
       (cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(),
        weights2_.data().get(), n_unique, n_cuts_cur_[icol]);
-    dh::safe_cuda(cudaGetLastError());
+    dh::safe_cuda(cudaGetLastError());  // NOLINT
   }
 }
@@ -311,14 +311,14 @@ struct GPUSketcher {
        has_weights_ ? weights_.data().get() : nullptr,
       entries_.data().get(), gpu_batch_nrows_, num_cols_,
       row_batch.offset[row_begin_ + batch_row_begin], batch_nrows);
-    dh::safe_cuda(cudaGetLastError());
-    dh::safe_cuda(cudaDeviceSynchronize());
+    dh::safe_cuda(cudaGetLastError());  // NOLINT
+    dh::safe_cuda(cudaDeviceSynchronize());  // NOLINT
 
     for (int icol = 0; icol < num_cols_; ++icol) {
       FindColumnCuts(batch_nrows, icol);
     }
 
-    dh::safe_cuda(cudaDeviceSynchronize());
+    dh::safe_cuda(cudaDeviceSynchronize());  // NOLINT
 
     // add cuts into sketches
     thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
@@ -379,7 +379,7 @@ struct GPUSketcher {
   }
 
   GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) {
-    devices_ = GPUSet::Range(param_.gpu_id, dh::NDevices(param_.n_gpus, n_rows));
+    devices_ = GPUSet::All(param_.n_gpus, n_rows).Normalised(param_.gpu_id);
   }
 
   std::vector<std::unique_ptr<DeviceShard>> shards_;
diff --git a/src/common/host_device_vector.h b/src/common/host_device_vector.h
index ebd54e84988c..3bab1009541d 100644
--- a/src/common/host_device_vector.h
+++ b/src/common/host_device_vector.h
@@ -11,6 +11,7 @@
 #include
 #include
 
+#include "gpu_set.h"
 #include "span.h"
 
 // only include thrust-related files if host_device_vector.h
@@ -23,40 +24,6 @@ namespace xgboost {
 
 template <typename T> struct HostDeviceVectorImpl;
 
-// set of devices across which HostDeviceVector can be distributed;
-// currently implemented as a range, but can be changed later to something else,
-// e.g. a bitset
-class GPUSet {
- public:
-  explicit GPUSet(int start = 0, int ndevices = 0)
-      : start_(start), ndevices_(ndevices) {}
-  static GPUSet Empty() { return GPUSet(); }
-  static GPUSet Range(int start, int ndevices) { return GPUSet(start, ndevices); }
-  int Size() const { return ndevices_; }
-  int operator[](int index) const {
-    CHECK(index >= 0 && index < ndevices_);
-    return start_ + index;
-  }
-  bool IsEmpty() const { return ndevices_ <= 0; }
-  int Index(int device) const {
-    CHECK(device >= start_ && device < start_ + ndevices_);
-    return device - start_;
-  }
-  bool Contains(int device) const {
-    return start_ <= device && device < start_ + ndevices_;
-  }
-  friend bool operator==(GPUSet a, GPUSet b) {
-    return a.start_ == b.start_ && a.ndevices_ == b.ndevices_;
-  }
-  friend bool operator!=(GPUSet a, GPUSet b) {
-    return a.start_ != b.start_ || a.ndevices_ != b.ndevices_;
-  }
-
- private:
-  int start_, ndevices_;
-};
-
-
 /**
  * @file host_device_vector.h
  * @brief A device-and-host vector abstraction layer.
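To make the new call pattern concrete for reviewers, here is a minimal, hypothetical usage sketch of the relocated `GPUSet` (not part of the patch; the include path and device counts are assumptions, and on a host-only build `AllVisible()` reports zero devices, which is why the counts are passed explicitly):

```cpp
// Sketch: how callers construct a GPUSet after this change.
#include <iostream>
#include "src/common/gpu_set.h"  // hypothetical path, relative to the repo root

int main() {
  // 8 GPUs requested but only 3 training rows: both act as upper bounds,
  // so the set is clamped to devices {0, 1, 2}.
  xgboost::GPUSet devices = xgboost::GPUSet::All(8, 3);

  std::cout << devices.Size() << '\n';          // 3
  for (int d : devices) std::cout << d << ' ';  // 0 1 2
  std::cout << '\n';
  std::cout << devices.Contains(2) << ' '       // 1: device 2 is in the set
            << devices.Index(2) << '\n';        // 2: its offset from the start
  return 0;
}
```

This replaces the old free-function pair `dh::NDevices(n_gpus, num_rows)` / `dh::NDevicesAll(n_gpus)` with one clamped constructor, and `Normalised(gpu_id)` / `Unnormalised()` convert between device ordinals counted from `gpu_id` and from zero.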
diff --git a/src/common/timer.h b/src/common/timer.h
index 32460ae56b35..22c0cf31a901 100644
--- a/src/common/timer.h
+++ b/src/common/timer.h
@@ -7,7 +7,8 @@
 #include
 #include
 #include
-#include <vector>
+
+#include "gpu_set.h"
 
 namespace xgboost {
 namespace common {
@@ -66,21 +67,21 @@ struct Monitor {
     this->label = label;
   }
   void Start(const std::string &name) { timer_map[name].Start(); }
-  void Start(const std::string &name, std::vector<int> dList) {
+  void Start(const std::string &name, GPUSet devices) {
     if (debug_verbose) {
 #ifdef __CUDACC__
 #include "device_helpers.cuh"
-      dh::SynchronizeNDevices(dList.size(), dList);
+      dh::SynchronizeNDevices(devices);
 #endif
     }
     timer_map[name].Start();
   }
   void Stop(const std::string &name) { timer_map[name].Stop(); }
-  void Stop(const std::string &name, std::vector<int> dList) {
+  void Stop(const std::string &name, GPUSet devices) {
     if (debug_verbose) {
 #ifdef __CUDACC__
 #include "device_helpers.cuh"
-      dh::SynchronizeNDevices(dList.size(), dList);
+      dh::SynchronizeNDevices(devices);
 #endif
     }
     timer_map[name].Stop();
diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu
index ca1536cd1e36..84761caaabdc 100644
--- a/src/linear/updater_gpu_coordinate.cu
+++ b/src/linear/updater_gpu_coordinate.cu
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include "../common/gpu_set.h"
 #include "../common/device_helpers.cuh"
 #include "../common/timer.h"
 #include "coordinate_common.h"
@@ -214,14 +215,14 @@ class GPUCoordinateUpdater : public LinearUpdater {
   void LazyInitShards(DMatrix *p_fmat,
                       const gbm::GBLinearModelParam &model_param) {
     if (!shards.empty()) return;
 
-    int n_devices = dh::NDevices(param.n_gpus, p_fmat->Info().num_row_);
+    int n_devices = GPUSet::All(param.n_gpus, p_fmat->Info().num_row_).Size();
 
     bst_uint row_begin = 0;
     bst_uint shard_size =
         std::ceil(static_cast<double>(p_fmat->Info().num_row_) / n_devices);
 
     device_list.resize(n_devices);
     for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
-      int device_idx = (param.gpu_id + d_idx) % dh::NVisibleDevices();
+      int device_idx = GPUSet::GetDeviceIdx(param.gpu_id + d_idx);
       device_list[d_idx] = device_idx;
     }
     // Partition input matrix into row segments
diff --git a/src/objective/regression_obj_gpu.cu b/src/objective/regression_obj_gpu.cu
index d4ac49bfe28b..2525fb80276d 100644
--- a/src/objective/regression_obj_gpu.cu
+++ b/src/objective/regression_obj_gpu.cu
@@ -102,8 +102,8 @@ class GPURegLossObj : public ObjFunction {
 
   void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
     param_.InitAllowUnknown(args);
-    CHECK(param_.n_gpus != 0) << "Must have at least one device";
-    devices_ = GPUSet::Range(param_.gpu_id, dh::NDevicesAll(param_.n_gpus));
+    // CHECK(param_.n_gpus != 0) << "Must have at least one device";
+    devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
   }
 
   void GetGradient(HostDeviceVector<bst_float>* preds,
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 21afc37f3f9c..1fba61656abf 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include "../common/gpu_set.h"
 #include "../common/device_helpers.cuh"
 #include "../common/host_device_vector.h"
 
@@ -464,7 +465,7 @@ class GPUPredictor : public xgboost::Predictor {
     Predictor::Init(cfg, cache);
     cpu_predictor->Init(cfg, cache);
     param.InitAllowUnknown(cfg);
-    devices = GPUSet::Range(param.gpu_id, dh::NDevicesAll(param.n_gpus));
+    devices = GPUSet::All(param.n_gpus).Normalised(param.gpu_id);
     max_shared_memory_bytes = dh::MaxSharedMemory(param.gpu_id);
   }
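`GPUSet::GetDeviceIdx` now centralises the wrap-around rule that the updaters above used to open-code as `(gpu_id + d_idx) % dh::NVisibleDevices()`. A standalone illustration of that rule (logic copied from the patch for demonstration only; `n_visible` stands in for `GPUSet::AllVisible().Size()` and must be positive, since a zero divisor would be undefined):

```cpp
#include <cstdlib>
#include <iostream>

// Same computation as GPUSet::GetDeviceIdx, with the visible-device
// count passed in so the sketch runs without CUDA.
int GetDeviceIdx(int gpu_id, int n_visible) {
  return std::abs(gpu_id) % n_visible;
}

int main() {
  std::cout << GetDeviceIdx(1, 4) << '\n';   // 1: in range, returned unchanged
  std::cout << GetDeviceIdx(5, 4) << '\n';   // 1: wraps past the last device
  std::cout << GetDeviceIdx(-3, 4) << '\n';  // 3: negative ids are folded by abs()
  return 0;
}
```

Note one behavioural difference: the old modulo left negative ids negative, while the new helper folds them with `abs()` before reducing.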
diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu
index ee23299ac813..f549f1922d56 100644
--- a/src/tree/updater_gpu.cu
+++ b/src/tree/updater_gpu.cu
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include "../common/gpu_set.h"
 #include "param.h"
 #include "updater_gpu_common.cuh"
 
@@ -375,7 +376,7 @@ void argMaxByKey(ExactSplitCandidate* nodeSplits, const GradientPair* gradScans,
                  NodeIdT nodeStart, int len, const TrainParam param,
                  ArgMaxByKeyAlgo algo) {
   dh::FillConst(
-      dh::GetDeviceIdx(param.gpu_id), nodeSplits, nUniqKeys,
+      GPUSet::GetDeviceIdx(param.gpu_id), nodeSplits, nUniqKeys,
       ExactSplitCandidate());
   int nBlks = dh::DivRoundUp(len, ITEMS_PER_THREAD * BLKDIM);
   switch (algo) {
@@ -498,7 +499,7 @@ class GPUMaker : public TreeUpdater {
 
   // devices are only used for resharding the HostDeviceVector passed as a parameter;
   // the algorithm works with a single GPU only
-  GPUSet devices;
+  GPUSet devices_;
 
   dh::CubMemory tmp_mem;
   dh::DVec tmpScanGradBuff;
@@ -516,7 +517,7 @@ class GPUMaker : public TreeUpdater {
     maxNodes = (1 << (param.max_depth + 1)) - 1;
     maxLeaves = 1 << param.max_depth;
 
-    devices = GPUSet::Range(param.gpu_id, dh::NDevicesAll(param.n_gpus));
+    devices_ = GPUSet::All(param.n_gpus).Normalised(param.gpu_id);
   }
 
   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
@@ -526,7 +527,7 @@ class GPUMaker : public TreeUpdater {
     float lr = param.learning_rate;
     param.learning_rate = lr / trees.size();
 
-    gpair->Reshard(devices);
+    gpair->Reshard(devices_);
 
     try {
       // build tree
@@ -624,7 +625,7 @@ class GPUMaker : public TreeUpdater {
 
   void allocateAllData(int offsetSize) {
     int tmpBuffSize = ScanTempBufferSize(nVals);
-    ba.Allocate(dh::GetDeviceIdx(param.gpu_id), param.silent, &vals, nVals,
+    ba.Allocate(GPUSet::GetDeviceIdx(param.gpu_id), param.silent, &vals, nVals,
                 &vals_cached, nVals, &instIds, nVals, &instIds_cached, nVals,
                 &colOffsets, offsetSize, &gradsInst, nRows, &nodeAssigns, nVals,
                 &nodeLocations, nVals, &nodes, maxNodes, &nodeAssignsPerInst,
@@ -634,7 +635,7 @@ class GPUMaker : public TreeUpdater {
   }
 
   void setupOneTimeData(DMatrix* dmat) {
-    size_t free_memory = dh::AvailableMemory(dh::GetDeviceIdx(param.gpu_id));
+    size_t free_memory = dh::AvailableMemory(GPUSet::GetDeviceIdx(param.gpu_id));
     if (!dmat->SingleColBlock()) {
       throw std::runtime_error("exact::GPUBuilder - must have 1 column block");
     }
@@ -730,7 +731,7 @@ class GPUMaker : public TreeUpdater {
                   nodeAssigns.Current(), instIds.Current(), nodes.Data(),
                   colOffsets.Data(), vals.Current(), nVals, nCols);
     // gather the node assignments across all other columns too
-    dh::Gather(dh::GetDeviceIdx(param.gpu_id), nodeAssigns.Current(),
+    dh::Gather(GPUSet::GetDeviceIdx(param.gpu_id), nodeAssigns.Current(),
                nodeAssignsPerInst.Data(), instIds.Current(), nVals);
     sortKeys(level);
   }
@@ -741,7 +742,7 @@ class GPUMaker : public TreeUpdater {
     // but we don't need more than level+1 bits for sorting!
     SegmentedSort(&tmp_mem, &nodeAssigns, &nodeLocations, nVals, nCols,
                   colOffsets, 0, level + 1);
-    dh::Gather(dh::GetDeviceIdx(param.gpu_id), vals.other(),
+    dh::Gather(GPUSet::GetDeviceIdx(param.gpu_id), vals.other(),
               vals.Current(), instIds.other(), instIds.Current(),
               nodeLocations.Current(), nVals);
     vals.buff().selector ^= 1;
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 69c616af183b..ba2761d4c488 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -728,7 +728,7 @@ class GPUHistMaker : public TreeUpdater {
     param_.InitAllowUnknown(args);
     CHECK(param_.n_gpus != 0) << "Must have at least one device";
     n_devices_ = param_.n_gpus;
-    devices_ = GPUSet::Range(param_.gpu_id, dh::NDevicesAll(param_.n_gpus));
+    devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
 
     dh::CheckComputeCapability();
 
@@ -743,7 +743,7 @@ class GPUHistMaker : public TreeUpdater {
 
   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
               const std::vector<RegTree*>& trees) override {
-    monitor_.Start("Update", device_list_);
+    monitor_.Start("Update", devices_);
     GradStats::CheckInfo(dmat->Info());
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
@@ -759,17 +759,17 @@ class GPUHistMaker : public TreeUpdater {
       LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
     }
     param_.learning_rate = lr;
-    monitor_.Stop("Update", device_list_);
+    monitor_.Stop("Update", devices_);
   }
 
   void InitDataOnce(DMatrix* dmat) {
     info_ = &dmat->Info();
 
-    int n_devices = dh::NDevices(param_.n_gpus, info_->num_row_);
+    int n_devices = GPUSet::All(param_.n_gpus, info_->num_row_).Size();
 
     device_list_.resize(n_devices);
     for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
-      int device_idx = (param_.gpu_id + d_idx) % dh::NVisibleDevices();
+      int device_idx = GPUSet::GetDeviceIdx(param_.gpu_id + d_idx);
       device_list_[d_idx] = device_idx;
     }
 
@@ -792,16 +792,16 @@ class GPUHistMaker : public TreeUpdater {
       shard->InitRowPtrs(batch);
     });
 
-    monitor_.Start("Quantiles", device_list_);
+    monitor_.Start("Quantiles", devices_);
     common::DeviceSketch(batch, *info_, param_, &hmat_);
     n_bins_ = hmat_.row_ptr.back();
-    monitor_.Stop("Quantiles", device_list_);
+    monitor_.Stop("Quantiles", devices_);
 
-    monitor_.Start("BinningCompression", device_list_);
+    monitor_.Start("BinningCompression", devices_);
     dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
       shard->InitCompressedData(hmat_, batch);
     });
-    monitor_.Stop("BinningCompression", device_list_);
+    monitor_.Stop("BinningCompression", devices_);
 
     CHECK(!iter->Next()) << "External memory not supported";
 
@@ -811,20 +811,20 @@ class GPUHistMaker : public TreeUpdater {
 
   void InitData(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
                 const RegTree& tree) {
-    monitor_.Start("InitDataOnce", device_list_);
+    monitor_.Start("InitDataOnce", devices_);
     if (!initialised_) {
       this->InitDataOnce(dmat);
     }
-    monitor_.Stop("InitDataOnce", device_list_);
+    monitor_.Stop("InitDataOnce", devices_);
 
     column_sampler_.Init(info_->num_col_, param_);
 
     // Copy gpair & reset memory
-    monitor_.Start("InitDataReset", device_list_);
+    monitor_.Start("InitDataReset", devices_);
     gpair->Reshard(devices_);
     dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {shard->Reset(gpair); });
-    monitor_.Stop("InitDataReset", device_list_);
+    monitor_.Stop("InitDataReset", devices_);
   }
 
   void AllReduceHist(int nidx) {
@@ -1036,12 +1036,12 @@ class GPUHistMaker : public TreeUpdater {
                 RegTree* p_tree) {
     auto& tree = *p_tree;
 
-    monitor_.Start("InitData", device_list_);
+    monitor_.Start("InitData", devices_);
     this->InitData(gpair, p_fmat, *p_tree);
-    monitor_.Stop("InitData", device_list_);
-    monitor_.Start("InitRoot", device_list_);
+    monitor_.Stop("InitData", devices_);
+    monitor_.Start("InitRoot", devices_);
     this->InitRoot(p_tree);
-    monitor_.Stop("InitRoot", device_list_);
+    monitor_.Stop("InitRoot", devices_);
 
     auto timestamp = qexpand_->size();
     auto num_leaves = 1;
@@ -1051,9 +1051,9 @@ class GPUHistMaker : public TreeUpdater {
       qexpand_->pop();
       if (!candidate.IsValid(param_, num_leaves)) continue;
       // std::cout << candidate;
-      monitor_.Start("ApplySplit", device_list_);
+      monitor_.Start("ApplySplit", devices_);
       this->ApplySplit(candidate, p_tree);
-      monitor_.Stop("ApplySplit", device_list_);
+      monitor_.Stop("ApplySplit", devices_);
       num_leaves++;
 
       auto left_child_nidx = tree[candidate.nid].LeftChild();
@@ -1062,12 +1062,12 @@ class GPUHistMaker : public TreeUpdater {
       // Only create child entries if needed
       if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
                                     num_leaves)) {
-        monitor_.Start("BuildHist", device_list_);
+        monitor_.Start("BuildHist", devices_);
         this->BuildHistLeftRight(candidate.nid, left_child_nidx, right_child_nidx);
-        monitor_.Stop("BuildHist", device_list_);
+        monitor_.Stop("BuildHist", devices_);
 
-        monitor_.Start("EvaluateSplits", device_list_);
+        monitor_.Start("EvaluateSplits", devices_);
         auto splits = this->EvaluateSplits({left_child_nidx, right_child_nidx},
                                            p_tree);
         qexpand_->push(ExpandEntry(left_child_nidx,
@@ -1076,21 +1076,21 @@ class GPUHistMaker : public TreeUpdater {
         qexpand_->push(ExpandEntry(right_child_nidx,
                                    tree.GetDepth(right_child_nidx), splits[1],
                                    timestamp++));
-        monitor_.Stop("EvaluateSplits", device_list_);
+        monitor_.Stop("EvaluateSplits", devices_);
       }
     }
   }
 
  bool UpdatePredictionCache(
      const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
-    monitor_.Start("UpdatePredictionCache", device_list_);
+    monitor_.Start("UpdatePredictionCache", devices_);
    if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
      return false;
    p_out_preds->Reshard(devices_);
    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
      shard->UpdatePredictionCache(p_out_preds->DevicePointer(shard->device_idx));
    });
-    monitor_.Stop("UpdatePredictionCache", device_list_);
+    monitor_.Stop("UpdatePredictionCache", devices_);
    return true;
  }
diff --git a/tests/cpp/common/test_gpuset.cc b/tests/cpp/common/test_gpuset.cc
new file mode 100644
index 000000000000..3d74ba27037e
--- /dev/null
+++ b/tests/cpp/common/test_gpuset.cc
@@ -0,0 +1,37 @@
+#include "../../../src/common/gpu_set.h"
+#include <gtest/gtest.h>
+
+namespace xgboost {
+
+TEST(GPUSet, Basic) {
+  GPUSet devices = GPUSet::Empty();
+  ASSERT_TRUE(devices.IsEmpty());
+
+  devices = GPUSet{0, 1};
+  ASSERT_TRUE(devices != GPUSet::Empty());
+  EXPECT_EQ(devices.Size(), 1);
+
+  EXPECT_ANY_THROW(devices.Index(1));
+  EXPECT_ANY_THROW(devices.Index(-1));
+
+  devices = GPUSet::Range(1, 0);
+  EXPECT_EQ(devices, GPUSet::Empty());
+  EXPECT_EQ(devices.Size(), 0);
+  EXPECT_TRUE(devices.IsEmpty());
+
+  EXPECT_FALSE(devices.Contains(1));
+
+  devices = GPUSet::Range(2, -1);
+  EXPECT_EQ(devices, GPUSet::Empty());
+  EXPECT_EQ(devices.Size(), 0);
+  EXPECT_TRUE(devices.IsEmpty());
+
+  devices = GPUSet::Range(2, 8);
+  EXPECT_EQ(devices.Size(), 8);
+  devices = devices.Unnormalised();
+
+  EXPECT_EQ(*devices.begin(), 0);
+  EXPECT_EQ(*devices.end(), devices.Size());
+}
+
+}  // namespace xgboost
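The new test covers construction, emptiness, and iteration. A possible follow-up case (a sketch, not part of this patch) would pin `Contains()` and `Index()` against the range bounds; the expectations below follow directly from the implementation in gpu_set.h:

```cpp
#include "../../../src/common/gpu_set.h"
#include <gtest/gtest.h>

namespace xgboost {

// Sketch of an additional case: Contains()/Index() agree with the
// half-open device range [start, start + ndevices).
TEST(GPUSet, ContainsIndex) {
  GPUSet devices = GPUSet::Range(2, 8);  // devices 2 through 9
  EXPECT_TRUE(devices.Contains(2));      // first device in the set
  EXPECT_TRUE(devices.Contains(9));      // last device in the set
  EXPECT_FALSE(devices.Contains(10));    // one past the end
  EXPECT_EQ(devices.Index(2), 0);        // offsets count from the start
  EXPECT_EQ(devices.Index(9), 7);
}

}  // namespace xgboost
```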