From be7f321382b22a1ab536965c5fdfefe131e488f6 Mon Sep 17 00:00:00 2001 From: shiyu1994 Date: Mon, 29 Aug 2022 14:38:04 +0800 Subject: [PATCH] [ci][fix] Fix cuda_exp ci (#5438) * fix cuda_exp ci * fix ci failures introduced by #5279 * cleanup cuda.yml * fix test.sh * clean up test.sh * clean up test.sh * skip lines by cuda_exp in test_register_logger * Update tests/python_package_test/test_utilities.py Co-authored-by: Nikita Titov Co-authored-by: Nikita Titov --- .github/workflows/cuda.yml | 15 ++- include/LightGBM/cuda/cuda_utils.h | 4 + include/LightGBM/tree_learner.h | 6 + src/boosting/cuda/cuda_score_updater.cpp | 2 +- src/boosting/cuda/cuda_score_updater.cu | 8 +- src/boosting/gbdt.cpp | 115 ++++++++++++------ src/boosting/gbdt.h | 2 + src/boosting/rf.hpp | 2 +- src/io/cuda/cuda_tree.cu | 46 +++++-- src/objective/objective_function.cpp | 2 +- .../cuda/cuda_single_gpu_tree_learner.cpp | 10 ++ .../cuda/cuda_single_gpu_tree_learner.hpp | 4 +- tests/python_package_test/test_utilities.py | 7 +- 13 files changed, 159 insertions(+), 64 deletions(-) diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 671434bffbd5..54a7aa1e45eb 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -11,12 +11,11 @@ on: env: github_actions: 'true' os_name: linux - task: cuda conda_env: test-env jobs: test: - name: ${{ matrix.tree_learner }} ${{ matrix.cuda_version }} ${{ matrix.method }} (linux, ${{ matrix.compiler }}, Python ${{ matrix.python_version }}) + name: ${{ matrix.task }} ${{ matrix.cuda_version }} ${{ matrix.method }} (linux, ${{ matrix.compiler }}, Python ${{ matrix.python_version }}) runs-on: [self-hosted, linux] timeout-minutes: 60 strategy: @@ -27,27 +26,27 @@ jobs: compiler: gcc python_version: "3.8" cuda_version: "11.7.1" - tree_learner: cuda + task: cuda - method: pip compiler: clang python_version: "3.9" cuda_version: "10.0" - tree_learner: cuda + task: cuda - method: wheel compiler: gcc python_version: "3.10" cuda_version: "9.0" - tree_learner: cuda + task: cuda - method: source compiler: gcc python_version: "3.8" cuda_version: "11.7.1" - tree_learner: cuda_exp + task: cuda_exp - method: pip compiler: clang python_version: "3.9" cuda_version: "10.0" - tree_learner: cuda_exp + task: cuda_exp steps: - name: Setup or update software on host machine run: | @@ -86,7 +85,7 @@ jobs: GITHUB_ACTIONS=${{ env.github_actions }} OS_NAME=${{ env.os_name }} COMPILER=${{ matrix.compiler }} - TASK=${{ env.task }} + TASK=${{ matrix.task }} METHOD=${{ matrix.method }} CONDA_ENV=${{ env.conda_env }} PYTHON_VERSION=${{ matrix.python_version }} diff --git a/include/LightGBM/cuda/cuda_utils.h b/include/LightGBM/cuda/cuda_utils.h index ee88c52a0404..21aee1641ee4 100644 --- a/include/LightGBM/cuda/cuda_utils.h +++ b/include/LightGBM/cuda/cuda_utils.h @@ -171,6 +171,10 @@ class CUDAVector { return data_; } + const T* RawDataReadOnly() const { + return data_; + } + private: T* data_; size_t size_; diff --git a/include/LightGBM/tree_learner.h b/include/LightGBM/tree_learner.h index 197a80f18cd7..772e1422ff69 100644 --- a/include/LightGBM/tree_learner.h +++ b/include/LightGBM/tree_learner.h @@ -50,6 +50,12 @@ class TreeLearner { */ virtual void ResetConfig(const Config* config) = 0; + /*! + * \brief Reset boosting_on_gpu_ + * \param boosting_on_gpu flag for boosting on GPU + */ + virtual void ResetBoostingOnGPU(const bool /*boosting_on_gpu*/) {} + virtual void SetForcedSplit(const Json* forced_split_json) = 0; /*! diff --git a/src/boosting/cuda/cuda_score_updater.cpp b/src/boosting/cuda/cuda_score_updater.cpp index 7e0b681e6df9..9c514265ee40 100644 --- a/src/boosting/cuda/cuda_score_updater.cpp +++ b/src/boosting/cuda/cuda_score_updater.cpp @@ -25,7 +25,7 @@ CUDAScoreUpdater::CUDAScoreUpdater(const Dataset* data, int num_tree_per_iterati has_init_score_ = true; CopyFromHostToCUDADevice(cuda_score_, init_score, total_size, __FILE__, __LINE__); } else { - SetCUDAMemory(cuda_score_, 0, static_cast(num_data_), __FILE__, __LINE__); + SetCUDAMemory(cuda_score_, 0, static_cast(total_size), __FILE__, __LINE__); } SynchronizeCUDADevice(__FILE__, __LINE__); if (boosting_on_cuda_) { diff --git a/src/boosting/cuda/cuda_score_updater.cu b/src/boosting/cuda/cuda_score_updater.cu index 1645d25d490f..c2138957f199 100644 --- a/src/boosting/cuda/cuda_score_updater.cu +++ b/src/boosting/cuda/cuda_score_updater.cu @@ -11,24 +11,22 @@ namespace LightGBM { __global__ void AddScoreConstantKernel( const double val, - const size_t offset, const data_size_t num_data, double* score) { const data_size_t data_index = static_cast(threadIdx.x + blockIdx.x * blockDim.x); if (data_index < num_data) { - score[data_index + offset] += val; + score[data_index] += val; } } void CUDAScoreUpdater::LaunchAddScoreConstantKernel(const double val, const size_t offset) { const int num_blocks = (num_data_ + num_threads_per_block_) / num_threads_per_block_; Log::Debug("Adding init score = %lf", val); - AddScoreConstantKernel<<>>(val, offset, num_data_, cuda_score_); + AddScoreConstantKernel<<>>(val, num_data_, cuda_score_ + offset); } __global__ void MultiplyScoreConstantKernel( const double val, - const size_t offset, const data_size_t num_data, double* score) { const data_size_t data_index = static_cast(threadIdx.x + blockIdx.x * blockDim.x); @@ -39,7 +37,7 @@ __global__ void MultiplyScoreConstantKernel( void CUDAScoreUpdater::LaunchMultiplyScoreConstantKernel(const double val, const size_t offset) { const int num_blocks = (num_data_ + num_threads_per_block_) / num_threads_per_block_; - MultiplyScoreConstantKernel<<>>(val, offset, num_data_, cuda_score_); + MultiplyScoreConstantKernel<<>>(val, num_data_, cuda_score_ + offset); } } // namespace LightGBM diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 9d48e535ea0e..ee29263bb7b6 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -41,6 +41,9 @@ GBDT::GBDT() average_output_ = false; tree_learner_ = nullptr; linear_tree_ = false; + gradients_pointer_ = nullptr; + hessians_pointer_ = nullptr; + boosting_on_gpu_ = false; } GBDT::~GBDT() { @@ -95,9 +98,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective is_constant_hessian_ = GetIsConstHessian(objective_function); - const bool boosting_on_gpu = objective_function_ != nullptr && objective_function_->IsCUDAObjective(); + boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective(); tree_learner_ = std::unique_ptr(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, - config_.get(), boosting_on_gpu)); + config_.get(), boosting_on_gpu_)); // init tree learner tree_learner_->Init(train_data_, is_constant_hessian_); @@ -112,7 +115,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective #ifdef USE_CUDA_EXP if (config_->device_type == std::string("cuda_exp")) { - train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu)); + train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu_)); } else { #endif // USE_CUDA_EXP train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_)); @@ -123,9 +126,14 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective num_data_ = train_data_->num_data(); // create buffer for gradients and Hessians if (objective_function_ != nullptr) { - size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; + const size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; #ifdef USE_CUDA_EXP - if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu) { + if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) { + if (gradients_pointer_ != nullptr) { + CHECK_NOTNULL(hessians_pointer_); + DeallocateCUDAMemory(&gradients_pointer_, __FILE__, __LINE__); + DeallocateCUDAMemory(&hessians_pointer_, __FILE__, __LINE__); + } AllocateCUDAMemory(&gradients_pointer_, total_size, __FILE__, __LINE__); AllocateCUDAMemory(&hessians_pointer_, total_size, __FILE__, __LINE__); } else { @@ -137,17 +145,14 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective #ifdef USE_CUDA_EXP } #endif // USE_CUDA_EXP - #ifndef USE_CUDA_EXP - } - #else // USE_CUDA_EXP - } else { - if (config_->device_type == std::string("cuda_exp")) { - size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; - AllocateCUDAMemory(&gradients_pointer_, total_size, __FILE__, __LINE__); - AllocateCUDAMemory(&hessians_pointer_, total_size, __FILE__, __LINE__); - } + } else if (config_->boosting == std::string("goss")) { + const size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; + gradients_.resize(total_size); + hessians_.resize(total_size); + gradients_pointer_ = gradients_.data(); + hessians_pointer_ = hessians_.data(); } - #endif // USE_CUDA_EXP + // get max feature index max_feature_idx_ = train_data_->num_total_features() - 1; // get label index @@ -440,23 +445,36 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) { Boosting(); gradients = gradients_pointer_; hessians = hessians_pointer_; - #ifndef USE_CUDA_EXP - } - #else // USE_CUDA_EXP } else { - if (config_->device_type == std::string("cuda_exp")) { - const size_t total_size = static_cast(num_data_ * num_class_); - CopyFromHostToCUDADevice(gradients_pointer_, gradients, total_size, __FILE__, __LINE__); - CopyFromHostToCUDADevice(hessians_pointer_, hessians, total_size, __FILE__, __LINE__); + // use customized objective function + CHECK(objective_function_ == nullptr); + if (config_->boosting == std::string("goss")) { + // need to copy customized gradients when using GOSS + int64_t total_size = static_cast(num_data_) * num_tree_per_iteration_; + #pragma omp parallel for schedule(static) + for (int64_t i = 0; i < total_size; ++i) { + gradients_[i] = gradients[i]; + hessians_[i] = hessians[i]; + } + CHECK_EQ(gradients_pointer_, gradients_.data()); + CHECK_EQ(hessians_pointer_, hessians_.data()); gradients = gradients_pointer_; hessians = hessians_pointer_; } } - #endif // USE_CUDA_EXP // bagging logic Bagging(iter_); + if (gradients != nullptr && is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_ && config_->boosting != std::string("goss")) { + // allocate gradients_ and hessians_ for copy gradients for using data subset + int64_t total_size = static_cast(num_data_) * num_tree_per_iteration_; + gradients_.resize(total_size); + hessians_.resize(total_size); + gradients_pointer_ = gradients_.data(); + hessians_pointer_ = hessians_.data(); + } + bool should_continue = false; for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { const size_t offset = static_cast(cur_tree_id) * num_data_; @@ -465,7 +483,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) { auto grad = gradients + offset; auto hess = hessians + offset; // need to copy gradients for bagging subset. - if (is_use_subset_ && bag_data_cnt_ < num_data_ && config_->device_type != std::string("cuda_exp")) { + if (is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_) { for (int i = 0; i < bag_data_cnt_; ++i) { gradients_pointer_[offset + i] = grad[bag_data_indices_[i]]; hessians_pointer_[offset + i] = hess[bag_data_indices_[i]]; @@ -591,13 +609,12 @@ void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) { std::vector GBDT::EvalOneMetric(const Metric* metric, const double* score) const { #ifdef USE_CUDA_EXP - const bool boosting_on_cuda = objective_function_ != nullptr && objective_function_->IsCUDAObjective(); const bool evaluation_on_cuda = metric->IsCUDAMetric(); - if ((boosting_on_cuda && evaluation_on_cuda) || (!boosting_on_cuda && !evaluation_on_cuda)) { + if ((boosting_on_gpu_ && evaluation_on_cuda) || (!boosting_on_gpu_ && !evaluation_on_cuda)) { #endif // USE_CUDA_EXP return metric->Eval(score, objective_function_); #ifdef USE_CUDA_EXP - } else if (boosting_on_cuda && !evaluation_on_cuda) { + } else if (boosting_on_gpu_ && !evaluation_on_cuda) { const size_t total_size = static_cast(num_data_) * static_cast(num_tree_per_iteration_); if (total_size > host_score_.size()) { host_score_.resize(total_size, 0.0f); @@ -804,9 +821,8 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* } training_metrics_.shrink_to_fit(); - #ifdef USE_CUDA_EXP - const bool boosting_on_gpu = objective_function_ != nullptr && objective_function_->IsCUDAObjective(); - #endif // USE_CUDA_EXP + boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective(); + tree_learner_->ResetBoostingOnGPU(boosting_on_gpu_); if (train_data != train_data_) { train_data_ = train_data; @@ -814,7 +830,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* // create score tracker #ifdef USE_CUDA_EXP if (config_->device_type == std::string("cuda_exp")) { - train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu)); + train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu_)); } else { #endif // USE_CUDA_EXP train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_)); @@ -834,9 +850,14 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* // create buffer for gradients and hessians if (objective_function_ != nullptr) { - size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; + const size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; #ifdef USE_CUDA_EXP - if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu) { + if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) { + if (gradients_pointer_ != nullptr) { + CHECK_NOTNULL(hessians_pointer_); + DeallocateCUDAMemory(&gradients_pointer_, __FILE__, __LINE__); + DeallocateCUDAMemory(&hessians_pointer_, __FILE__, __LINE__); + } AllocateCUDAMemory(&gradients_pointer_, total_size, __FILE__, __LINE__); AllocateCUDAMemory(&hessians_pointer_, total_size, __FILE__, __LINE__); } else { @@ -848,6 +869,12 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* #ifdef USE_CUDA_EXP } #endif // USE_CUDA_EXP + } else if (config_->boosting == std::string("goss")) { + const size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; + gradients_.resize(total_size); + hessians_.resize(total_size); + gradients_pointer_ = gradients_.data(); + hessians_pointer_ = hessians_.data(); } max_feature_idx_ = train_data_->num_total_features() - 1; @@ -879,6 +906,10 @@ void GBDT::ResetConfig(const Config* config) { if (tree_learner_ != nullptr) { tree_learner_->ResetConfig(new_config.get()); } + + boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective(); + tree_learner_->ResetBoostingOnGPU(boosting_on_gpu_); + if (train_data_ != nullptr) { ResetBaggingConfig(new_config.get(), false); } @@ -953,10 +984,16 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) { need_re_bagging_ = true; if (is_use_subset_ && bag_data_cnt_ < num_data_) { - if (objective_function_ == nullptr) { - size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; + // resize gradient vectors to copy the customized gradients for goss or bagging with subset + if (objective_function_ != nullptr) { + const size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; #ifdef USE_CUDA_EXP - if (config_->device_type == std::string("cuda_exp") && objective_function_ != nullptr && objective_function_->IsCUDAObjective()) { + if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) { + if (gradients_pointer_ != nullptr) { + CHECK_NOTNULL(hessians_pointer_); + DeallocateCUDAMemory(&gradients_pointer_, __FILE__, __LINE__); + DeallocateCUDAMemory(&hessians_pointer_, __FILE__, __LINE__); + } AllocateCUDAMemory(&gradients_pointer_, total_size, __FILE__, __LINE__); AllocateCUDAMemory(&hessians_pointer_, total_size, __FILE__, __LINE__); } else { @@ -968,6 +1005,12 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) { #ifdef USE_CUDA_EXP } #endif // USE_CUDA_EXP + } else if (config_->boosting == std::string("goss")) { + const size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; + gradients_.resize(total_size); + hessians_.resize(total_size); + gradients_pointer_ = gradients_.data(); + hessians_pointer_ = hessians_.data(); } } } else { diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index f699719b525e..052fd0ea9051 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -504,6 +504,8 @@ class GBDT : public GBDTBase { score_t* gradients_pointer_; /*! \brief Pointer to hessian vector, can be on CPU or GPU */ score_t* hessians_pointer_; + /*! \brief Whether boosting is done on GPU, used for cuda_exp */ + bool boosting_on_gpu_; #ifdef USE_CUDA_EXP /*! \brief Buffer for scores when boosting is on GPU but evaluation is not, used only with cuda_exp */ mutable std::vector host_score_; diff --git a/src/boosting/rf.hpp b/src/boosting/rf.hpp index 5a9eb226fef5..8755a2eb0e2e 100644 --- a/src/boosting/rf.hpp +++ b/src/boosting/rf.hpp @@ -116,7 +116,7 @@ class RF : public GBDT { auto hess = hessians + offset; // need to copy gradients for bagging subset. - if (is_use_subset_ && bag_data_cnt_ < num_data_) { + if (is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_) { for (int i = 0; i < bag_data_cnt_; ++i) { tmp_grad_[i] = grad[bag_data_indices_[i]]; tmp_hess_[i] = hess[bag_data_indices_[i]]; diff --git a/src/io/cuda/cuda_tree.cu b/src/io/cuda/cuda_tree.cu index a15c83e89fae..2a6448259d7f 100644 --- a/src/io/cuda/cuda_tree.cu +++ b/src/io/cuda/cuda_tree.cu @@ -35,6 +35,16 @@ __device__ bool IsZeroCUDA(double fval) { return (fval >= -kZeroThreshold && fval <= kZeroThreshold); } +template +__device__ bool FindInBitsetCUDA(const uint32_t* bits, int n, T pos) { + int i1 = pos / 32; + if (i1 >= n) { + return false; + } + int i2 = pos % 32; + return (bits[i1] >> i2) & 1; +} + __global__ void SplitKernel( // split information const int leaf_index, const int real_feature_index, @@ -323,11 +333,13 @@ __global__ void AddPredictionToScoreKernel( const int* cuda_left_child, const int* cuda_right_child, const double* cuda_leaf_value, + const uint32_t* cuda_bitset_inner, + const int* cuda_cat_boundaries_inner, // output double* score) { const data_size_t inner_data_index = static_cast(threadIdx.x + blockIdx.x * blockDim.x); - const data_size_t data_index = USE_INDICES ? cuda_used_indices[inner_data_index] : inner_data_index; - if (data_index < num_data) { + if (inner_data_index < num_data) { + const data_size_t data_index = USE_INDICES ? cuda_used_indices[inner_data_index] : inner_data_index; int node = 0; while (node >= 0) { const int split_feature_inner = cuda_split_feature_inner[node]; @@ -352,20 +364,30 @@ __global__ void AddPredictionToScoreKernel( bin = most_freq_bin; } const int8_t decision_type = cuda_decision_type[node]; - const uint32_t threshold_in_bin = cuda_threshold_in_bin[node]; - const int8_t missing_type = ((decision_type >> 2) & 3); - const bool default_left = ((decision_type & kDefaultLeftMask) > 0); - if ((missing_type == 1 && bin == default_bin) || (missing_type == 2 && bin == max_bin)) { - if (default_left) { + if (GetDecisionTypeCUDA(decision_type, kCategoricalMask)) { + int cat_idx = static_cast(cuda_threshold_in_bin[node]); + if (FindInBitsetCUDA(cuda_bitset_inner + cuda_cat_boundaries_inner[cat_idx], + cuda_cat_boundaries_inner[cat_idx + 1] - cuda_cat_boundaries_inner[cat_idx], bin)) { node = cuda_left_child[node]; } else { node = cuda_right_child[node]; } } else { - if (bin <= threshold_in_bin) { - node = cuda_left_child[node]; + const uint32_t threshold_in_bin = cuda_threshold_in_bin[node]; + const int8_t missing_type = GetMissingTypeCUDA(decision_type); + const bool default_left = ((decision_type & kDefaultLeftMask) > 0); + if ((missing_type == 1 && bin == default_bin) || (missing_type == 2 && bin == max_bin)) { + if (default_left) { + node = cuda_left_child[node]; + } else { + node = cuda_right_child[node]; + } } else { - node = cuda_right_child[node]; + if (bin <= threshold_in_bin) { + node = cuda_left_child[node]; + } else { + node = cuda_right_child[node]; + } } } } @@ -400,6 +422,8 @@ void CUDATree::LaunchAddPredictionToScoreKernel( cuda_left_child_, cuda_right_child_, cuda_leaf_value_, + cuda_bitset_inner_.RawDataReadOnly(), + cuda_cat_boundaries_inner_.RawDataReadOnly(), // output score); } else { @@ -422,6 +446,8 @@ void CUDATree::LaunchAddPredictionToScoreKernel( cuda_left_child_, cuda_right_child_, cuda_leaf_value_, + cuda_bitset_inner_.RawDataReadOnly(), + cuda_cat_boundaries_inner_.RawDataReadOnly(), // output score); } diff --git a/src/objective/objective_function.cpp b/src/objective/objective_function.cpp index 04bfc45df6c7..2f719b44a988 100644 --- a/src/objective/objective_function.cpp +++ b/src/objective/objective_function.cpp @@ -14,7 +14,7 @@ namespace LightGBM { ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) { #ifdef USE_CUDA_EXP - if (config.device_type == std::string("cuda_exp")) { + if (config.device_type == std::string("cuda_exp") && config.boosting == std::string("gbdt")) { if (type == std::string("regression")) { Log::Warning("Objective regression is not implemented in cuda_exp version. Fall back to boosting on CPU."); return new RegressionL2loss(config); diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp index 55595765fcc6..fd48201e2c3a 100644 --- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp +++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp @@ -447,6 +447,16 @@ void CUDASingleGPUTreeLearner::AllocateBitset() { cuda_bitset_inner_len_ = 0; } +void CUDASingleGPUTreeLearner::ResetBoostingOnGPU(const bool boosting_on_cuda) { + boosting_on_cuda_ = boosting_on_cuda; + DeallocateCUDAMemory(&cuda_gradients_, __FILE__, __LINE__); + DeallocateCUDAMemory(&cuda_hessians_, __FILE__, __LINE__); + if (!boosting_on_cuda_) { + AllocateCUDAMemory(&cuda_gradients_, static_cast(num_data_), __FILE__, __LINE__); + AllocateCUDAMemory(&cuda_hessians_, static_cast(num_data_), __FILE__, __LINE__); + } +} + #ifdef DEBUG void CUDASingleGPUTreeLearner::CheckSplitValid( const int left_leaf, diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp b/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp index 942a6d1cb17d..a55f9df8fc15 100644 --- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp +++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp @@ -49,6 +49,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner { Tree* FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, const score_t* gradients, const score_t* hessians) const override; + void ResetBoostingOnGPU(const bool boosting_on_gpu) override; + protected: void BeforeTrain() override; @@ -119,7 +121,7 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner { /*! \brief hessians on CUDA */ score_t* cuda_hessians_; /*! \brief whether boosting is done on CUDA */ - const bool boosting_on_cuda_; + bool boosting_on_cuda_; #ifdef DEBUG /*! \brief gradients on CPU */ diff --git a/tests/python_package_test/test_utilities.py b/tests/python_package_test/test_utilities.py index 1b16550c5d11..08913ceb6e38 100644 --- a/tests/python_package_test/test_utilities.py +++ b/tests/python_package_test/test_utilities.py @@ -91,11 +91,16 @@ def dummy_metric(_, __): "INFO | [LightGBM] [Warning] CUDA currently requires double precision calculations.", "INFO | [LightGBM] [Info] LightGBM using CUDA trainer with DP float!!" ] + cuda_exp_lines = [ + "INFO | [LightGBM] [Warning] Objective binary is not implemented in cuda_exp version. Fall back to boosting on CPU.", + "INFO | [LightGBM] [Warning] Metric auc is not implemented in cuda_exp version. Fall back to evaluation on CPU.", + "INFO | [LightGBM] [Warning] Metric binary_error is not implemented in cuda_exp version. Fall back to evaluation on CPU.", + ] with open(log_filename, "rt", encoding="utf-8") as f: actual_log = f.read().strip() actual_log_wo_gpu_stuff = [] for line in actual_log.split("\n"): - if not any(line.startswith(gpu_line) for gpu_line in gpu_lines): + if not any(line.startswith(gpu_or_cuda_exp_line) for gpu_or_cuda_exp_line in gpu_lines + cuda_exp_lines): actual_log_wo_gpu_stuff.append(line) assert "\n".join(actual_log_wo_gpu_stuff) == expected_log