Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[ci][fix] Fix cuda_exp ci #5438

Merged
merged 10 commits into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions .github/workflows/cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ on:
env:
github_actions: 'true'
os_name: linux
task: cuda
conda_env: test-env

jobs:
test:
name: ${{ matrix.tree_learner }} ${{ matrix.cuda_version }} ${{ matrix.method }} (linux, ${{ matrix.compiler }}, Python ${{ matrix.python_version }})
name: ${{ matrix.task }} ${{ matrix.cuda_version }} ${{ matrix.method }} (linux, ${{ matrix.compiler }}, Python ${{ matrix.python_version }})
runs-on: [self-hosted, linux]
timeout-minutes: 60
strategy:
Expand All @@ -27,27 +26,27 @@ jobs:
compiler: gcc
python_version: "3.8"
cuda_version: "11.7.1"
tree_learner: cuda
task: cuda
- method: pip
compiler: clang
python_version: "3.9"
cuda_version: "10.0"
tree_learner: cuda
task: cuda
- method: wheel
compiler: gcc
python_version: "3.10"
cuda_version: "9.0"
tree_learner: cuda
task: cuda
- method: source
compiler: gcc
python_version: "3.8"
cuda_version: "11.7.1"
tree_learner: cuda_exp
task: cuda_exp
- method: pip
compiler: clang
python_version: "3.9"
cuda_version: "10.0"
tree_learner: cuda_exp
task: cuda_exp
steps:
- name: Setup or update software on host machine
run: |
Expand Down Expand Up @@ -86,7 +85,7 @@ jobs:
GITHUB_ACTIONS=${{ env.github_actions }}
OS_NAME=${{ env.os_name }}
COMPILER=${{ matrix.compiler }}
TASK=${{ env.task }}
TASK=${{ matrix.task }}
METHOD=${{ matrix.method }}
CONDA_ENV=${{ env.conda_env }}
PYTHON_VERSION=${{ matrix.python_version }}
Expand Down
4 changes: 4 additions & 0 deletions include/LightGBM/cuda/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ class CUDAVector {
return data_;
}

/*!
* \brief Read-only access to the stored data pointer
* \return data_ as a pointer-to-const; callable on a const object and modifies no members
*/
const T* RawDataReadOnly() const {
return data_;
}

private:
T* data_;
size_t size_;
Expand Down
6 changes: 6 additions & 0 deletions include/LightGBM/tree_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class TreeLearner {
*/
virtual void ResetConfig(const Config* config) = 0;

/*!
* \brief Reset boosting_on_gpu_; this default implementation ignores the flag and does nothing
* \param boosting_on_gpu flag for boosting on GPU
*/
virtual void ResetBoostingOnGPU(const bool /*boosting_on_gpu*/) {}

virtual void SetForcedSplit(const Json* forced_split_json) = 0;

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/cuda/cuda_score_updater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ CUDAScoreUpdater::CUDAScoreUpdater(const Dataset* data, int num_tree_per_iterati
has_init_score_ = true;
CopyFromHostToCUDADevice<double>(cuda_score_, init_score, total_size, __FILE__, __LINE__);
} else {
SetCUDAMemory<double>(cuda_score_, 0, static_cast<size_t>(num_data_), __FILE__, __LINE__);
SetCUDAMemory<double>(cuda_score_, 0, static_cast<size_t>(total_size), __FILE__, __LINE__);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
if (boosting_on_cuda_) {
Expand Down
8 changes: 3 additions & 5 deletions src/boosting/cuda/cuda_score_updater.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,22 @@ namespace LightGBM {

__global__ void AddScoreConstantKernel(
const double val,
const size_t offset,
const data_size_t num_data,
double* score) {
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
if (data_index < num_data) {
score[data_index + offset] += val;
score[data_index] += val;
}
}

void CUDAScoreUpdater::LaunchAddScoreConstantKernel(const double val, const size_t offset) {
const int num_blocks = (num_data_ + num_threads_per_block_) / num_threads_per_block_;
Log::Debug("Adding init score = %lf", val);
AddScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, offset, num_data_, cuda_score_);
AddScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, num_data_, cuda_score_ + offset);
}

__global__ void MultiplyScoreConstantKernel(
const double val,
const size_t offset,
const data_size_t num_data,
double* score) {
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
Expand All @@ -39,7 +37,7 @@ __global__ void MultiplyScoreConstantKernel(

void CUDAScoreUpdater::LaunchMultiplyScoreConstantKernel(const double val, const size_t offset) {
const int num_blocks = (num_data_ + num_threads_per_block_) / num_threads_per_block_;
MultiplyScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, offset, num_data_, cuda_score_);
MultiplyScoreConstantKernel<<<num_blocks, num_threads_per_block_>>>(val, num_data_, cuda_score_ + offset);
}

} // namespace LightGBM
Expand Down
115 changes: 79 additions & 36 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ GBDT::GBDT()
average_output_ = false;
tree_learner_ = nullptr;
linear_tree_ = false;
gradients_pointer_ = nullptr;
hessians_pointer_ = nullptr;
boosting_on_gpu_ = false;
}

GBDT::~GBDT() {
Expand Down Expand Up @@ -95,9 +98,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective

is_constant_hessian_ = GetIsConstHessian(objective_function);

const bool boosting_on_gpu = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type,
config_.get(), boosting_on_gpu));
config_.get(), boosting_on_gpu_));

// init tree learner
tree_learner_->Init(train_data_, is_constant_hessian_);
Expand All @@ -112,7 +115,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective

#ifdef USE_CUDA_EXP
if (config_->device_type == std::string("cuda_exp")) {
train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu));
train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu_));
} else {
#endif // USE_CUDA_EXP
train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
Expand All @@ -123,9 +126,14 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
num_data_ = train_data_->num_data();
// create buffer for gradients and Hessians
if (objective_function_ != nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
#ifdef USE_CUDA_EXP
if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu) {
if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) {
if (gradients_pointer_ != nullptr) {
CHECK_NOTNULL(hessians_pointer_);
DeallocateCUDAMemory<score_t>(&gradients_pointer_, __FILE__, __LINE__);
DeallocateCUDAMemory<score_t>(&hessians_pointer_, __FILE__, __LINE__);
}
AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
} else {
Expand All @@ -137,17 +145,14 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
#ifdef USE_CUDA_EXP
}
#endif // USE_CUDA_EXP
#ifndef USE_CUDA_EXP
}
#else // USE_CUDA_EXP
} else {
if (config_->device_type == std::string("cuda_exp")) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
}
} else if (config_->boosting == std::string("goss")) {
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
gradients_pointer_ = gradients_.data();
hessians_pointer_ = hessians_.data();
}
#endif // USE_CUDA_EXP

// get max feature index
max_feature_idx_ = train_data_->num_total_features() - 1;
// get label index
Expand Down Expand Up @@ -440,23 +445,36 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
Boosting();
gradients = gradients_pointer_;
hessians = hessians_pointer_;
#ifndef USE_CUDA_EXP
}
#else // USE_CUDA_EXP
} else {
if (config_->device_type == std::string("cuda_exp")) {
const size_t total_size = static_cast<size_t>(num_data_ * num_class_);
CopyFromHostToCUDADevice<score_t>(gradients_pointer_, gradients, total_size, __FILE__, __LINE__);
CopyFromHostToCUDADevice<score_t>(hessians_pointer_, hessians, total_size, __FILE__, __LINE__);
// use customized objective function
CHECK(objective_function_ == nullptr);
if (config_->boosting == std::string("goss")) {
// need to copy customized gradients when using GOSS
int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration_;
#pragma omp parallel for schedule(static)
for (int64_t i = 0; i < total_size; ++i) {
gradients_[i] = gradients[i];
hessians_[i] = hessians[i];
}
CHECK_EQ(gradients_pointer_, gradients_.data());
CHECK_EQ(hessians_pointer_, hessians_.data());
Comment on lines +459 to +460
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these checks slow? Maybe move them under #ifdef DEBUG?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. These checks are fast.

gradients = gradients_pointer_;
hessians = hessians_pointer_;
}
}
#endif // USE_CUDA_EXP

// bagging logic
Bagging(iter_);

if (gradients != nullptr && is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_ && config_->boosting != std::string("goss")) {
// allocate gradients_ and hessians_ so that gradients can be copied when using a data subset
int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
gradients_pointer_ = gradients_.data();
hessians_pointer_ = hessians_.data();
}

bool should_continue = false;
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
const size_t offset = static_cast<size_t>(cur_tree_id) * num_data_;
Expand All @@ -465,7 +483,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
auto grad = gradients + offset;
auto hess = hessians + offset;
// need to copy gradients for bagging subset.
if (is_use_subset_ && bag_data_cnt_ < num_data_ && config_->device_type != std::string("cuda_exp")) {
if (is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_) {
for (int i = 0; i < bag_data_cnt_; ++i) {
gradients_pointer_[offset + i] = grad[bag_data_indices_[i]];
hessians_pointer_[offset + i] = hess[bag_data_indices_[i]];
Expand Down Expand Up @@ -591,13 +609,12 @@ void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {

std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score) const {
#ifdef USE_CUDA_EXP
const bool boosting_on_cuda = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
const bool evaluation_on_cuda = metric->IsCUDAMetric();
if ((boosting_on_cuda && evaluation_on_cuda) || (!boosting_on_cuda && !evaluation_on_cuda)) {
if ((boosting_on_gpu_ && evaluation_on_cuda) || (!boosting_on_gpu_ && !evaluation_on_cuda)) {
#endif // USE_CUDA_EXP
return metric->Eval(score, objective_function_);
#ifdef USE_CUDA_EXP
} else if (boosting_on_cuda && !evaluation_on_cuda) {
} else if (boosting_on_gpu_ && !evaluation_on_cuda) {
const size_t total_size = static_cast<size_t>(num_data_) * static_cast<size_t>(num_tree_per_iteration_);
if (total_size > host_score_.size()) {
host_score_.resize(total_size, 0.0f);
Expand Down Expand Up @@ -804,17 +821,16 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
}
training_metrics_.shrink_to_fit();

#ifdef USE_CUDA_EXP
const bool boosting_on_gpu = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
#endif // USE_CUDA_EXP
boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
tree_learner_->ResetBoostingOnGPU(boosting_on_gpu_);

if (train_data != train_data_) {
train_data_ = train_data;
// not same training data, need reset score and others
// create score tracker
#ifdef USE_CUDA_EXP
if (config_->device_type == std::string("cuda_exp")) {
train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu));
train_score_updater_.reset(new CUDAScoreUpdater(train_data_, num_tree_per_iteration_, boosting_on_gpu_));
} else {
#endif // USE_CUDA_EXP
train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
Expand All @@ -834,9 +850,14 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*

// create buffer for gradients and hessians
if (objective_function_ != nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
#ifdef USE_CUDA_EXP
if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu) {
if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) {
if (gradients_pointer_ != nullptr) {
CHECK_NOTNULL(hessians_pointer_);
DeallocateCUDAMemory<score_t>(&gradients_pointer_, __FILE__, __LINE__);
DeallocateCUDAMemory<score_t>(&hessians_pointer_, __FILE__, __LINE__);
}
AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
} else {
Expand All @@ -848,6 +869,12 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
#ifdef USE_CUDA_EXP
}
#endif // USE_CUDA_EXP
} else if (config_->boosting == std::string("goss")) {
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
gradients_pointer_ = gradients_.data();
hessians_pointer_ = hessians_.data();
}

max_feature_idx_ = train_data_->num_total_features() - 1;
Expand Down Expand Up @@ -879,6 +906,10 @@ void GBDT::ResetConfig(const Config* config) {
if (tree_learner_ != nullptr) {
tree_learner_->ResetConfig(new_config.get());
}

boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective();
tree_learner_->ResetBoostingOnGPU(boosting_on_gpu_);

if (train_data_ != nullptr) {
ResetBaggingConfig(new_config.get(), false);
}
Expand Down Expand Up @@ -953,10 +984,16 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
need_re_bagging_ = true;

if (is_use_subset_ && bag_data_cnt_ < num_data_) {
if (objective_function_ == nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
// resize gradient vectors to copy the customized gradients for goss or bagging with subset
if (objective_function_ != nullptr) {
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
#ifdef USE_CUDA_EXP
if (config_->device_type == std::string("cuda_exp") && objective_function_ != nullptr && objective_function_->IsCUDAObjective()) {
if (config_->device_type == std::string("cuda_exp") && boosting_on_gpu_) {
if (gradients_pointer_ != nullptr) {
CHECK_NOTNULL(hessians_pointer_);
DeallocateCUDAMemory<score_t>(&gradients_pointer_, __FILE__, __LINE__);
DeallocateCUDAMemory<score_t>(&hessians_pointer_, __FILE__, __LINE__);
}
AllocateCUDAMemory<score_t>(&gradients_pointer_, total_size, __FILE__, __LINE__);
AllocateCUDAMemory<score_t>(&hessians_pointer_, total_size, __FILE__, __LINE__);
} else {
Expand All @@ -968,6 +1005,12 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
#ifdef USE_CUDA_EXP
}
#endif // USE_CUDA_EXP
} else if (config_->boosting == std::string("goss")) {
const size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
gradients_pointer_ = gradients_.data();
hessians_pointer_ = hessians_.data();
}
}
} else {
Expand Down
2 changes: 2 additions & 0 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ class GBDT : public GBDTBase {
score_t* gradients_pointer_;
/*! \brief Pointer to hessian vector, can be on CPU or GPU */
score_t* hessians_pointer_;
/*! \brief Whether boosting is done on GPU, used for cuda_exp */
bool boosting_on_gpu_;
#ifdef USE_CUDA_EXP
/*! \brief Buffer for scores when boosting is on GPU but evaluation is not, used only with cuda_exp */
mutable std::vector<double> host_score_;
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/rf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class RF : public GBDT {
auto hess = hessians + offset;

// need to copy gradients for bagging subset.
if (is_use_subset_ && bag_data_cnt_ < num_data_) {
if (is_use_subset_ && bag_data_cnt_ < num_data_ && !boosting_on_gpu_) {
for (int i = 0; i < bag_data_cnt_; ++i) {
tmp_grad_[i] = grad[bag_data_indices_[i]];
tmp_hess_[i] = hess[bag_data_indices_[i]];
Expand Down
Loading