fix bug in parallel learning (#2851)
* refix

* fix config

* avoid relying on config
guolinke authored Mar 2, 2020
1 parent 9c386db commit da91c61
Showing 7 changed files with 95 additions and 68 deletions.
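
The shape of the fix: SerialTreeLearner::Split() becomes a thin wrapper over a new SplitInner() that takes an update_cnt flag, and the data- and voting-parallel learners call SplitInner(..., false) so that the leaf counts they already hold in the split info (which in data-parallel mode reflect rows on all machines) are not overwritten by this machine's local partition counts. Below is a minimal, self-contained C++ sketch of that pattern; only the names Split, SplitInner, update_cnt, and global_data_count_in_leaf_ come from the diff, while every type here is a simplified stand-in rather than a real LightGBM class.

#include <cstdio>
#include <map>

// Stand-in for SplitInfo: in data-parallel mode left_count/right_count
// already hold counts aggregated over all machines before Split() runs.
struct SplitCounts {
  int left_count;
  int right_count;
};

class SerialLearnerModel {
 public:
  virtual ~SerialLearnerModel() = default;

  // Serial path: Split() is now just SplitInner() with update_cnt = true,
  // so the locally computed partition counts overwrite the split info.
  virtual void Split(SplitCounts* info) {
    SplitInner(info, /*update_cnt=*/true);
  }

  // Shared split logic; update_cnt controls whether the local leaf counts
  // are written back into the split info.
  void SplitInner(SplitCounts* info, bool update_cnt) {
    const int local_left = 40, local_right = 60;  // this machine's rows only
    if (update_cnt) {
      info->left_count = local_left;
      info->right_count = local_right;
    }
    // ... partition the data and grow the tree using info->left_count /
    // info->right_count ...
  }
};

class DataParallelLearnerModel : public SerialLearnerModel {
 public:
  // Parallel path: keep the aggregated counts, then record them for later
  // use (mirrors global_data_count_in_leaf_ in the real learners).
  void Split(SplitCounts* info) override {
    SplitInner(info, /*update_cnt=*/false);
    global_data_count_in_leaf_[0] = info->left_count;
    global_data_count_in_leaf_[1] = info->right_count;
  }
  std::map<int, int> global_data_count_in_leaf_;
};

int main() {
  SplitCounts synced{400, 600};  // pretend these were summed across machines
  DataParallelLearnerModel learner;
  learner.Split(&synced);
  std::printf("left=%d right=%d\n", synced.left_count, synced.right_count);
  return 0;  // prints left=400 right=600: the synced counts are preserved
}

Compiled and run, this prints left=400 right=600: with update_cnt=false the synchronized counts survive into the later tree-growing and global bookkeeping steps, which is what the parallel learners need.
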
2 changes: 1 addition & 1 deletion include/LightGBM/config.h
@@ -908,7 +908,7 @@ struct Config {
size_t file_load_progress_interval_bytes = size_t(10) * 1024 * 1024 * 1024;

bool is_parallel = false;
- bool is_parallel_find_bin = false;
+ bool is_data_based_parallel = false;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params);
static const std::unordered_map<std::string, std::string>& alias_table();
static const std::unordered_set<std::string>& parameter_set();
4 changes: 2 additions & 2 deletions src/application/application.cpp
@@ -93,15 +93,15 @@ void Application::LoadData() {
}

// sync up random seed for data partition
- if (config_.is_parallel_find_bin) {
+ if (config_.is_data_based_parallel) {
config_.data_random_seed = Network::GlobalSyncUpByMin(config_.data_random_seed);
}

Log::Debug("Loading train file...");
DatasetLoader dataset_loader(config_, predict_fun,
config_.num_class, config_.data.c_str());
// load Training data
- if (config_.is_parallel_find_bin) {
+ if (config_.is_data_based_parallel) {
// load data for parallel training
train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(),
Network::rank(), Network::num_machines()));
4 changes: 2 additions & 2 deletions src/io/config.cpp
@@ -280,10 +280,10 @@ void Config::CheckParamConflict() {
}

if (is_single_tree_learner || tree_learner == std::string("feature")) {
- is_parallel_find_bin = false;
+ is_data_based_parallel = false;
} else if (tree_learner == std::string("data")
|| tree_learner == std::string("voting")) {
- is_parallel_find_bin = true;
+ is_data_based_parallel = true;
if (histogram_pool_size >= 0
&& tree_learner == std::string("data")) {
Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
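
The renamed flag is derived entirely from the user-facing tree_learner parameter, as the hunk above shows. A hypothetical stand-alone helper expressing the same mapping (not part of LightGBM, purely illustrative) would look like:

#include <string>

// Hypothetical helper, not LightGBM code: it restates the mapping from the
// user-facing tree_learner value to the renamed internal flag above.
bool IsDataBasedParallel(const std::string& tree_learner) {
  // "data" and "voting" partition rows across machines, so they need the
  // data-based parallel code paths; "serial" and "feature" do not.
  return tree_learner == std::string("data") ||
         tree_learner == std::string("voting");
}
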
2 changes: 1 addition & 1 deletion src/treelearner/data_parallel_tree_learner.cpp
@@ -241,7 +241,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const

template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
- TREELEARNER_T::Split(tree, best_Leaf, left_leaf, right_leaf);
+ this->SplitInner(tree, best_Leaf, left_leaf, right_leaf, false);
const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf];
// need update global number of data in leaf
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
142 changes: 82 additions & 60 deletions src/treelearner/serial_tree_learner.cpp
@@ -648,89 +648,111 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
return result_count;
}

- void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf) {
- Common::FunctionTimer fun_timer("SerialTreeLearner::Split", global_timer);
+ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
+ int* right_leaf, bool update_cnt) {
+ Common::FunctionTimer fun_timer("SerialTreeLearner::SplitInner", global_timer);
SplitInfo& best_split_info = best_split_per_leaf_[best_leaf];
- const int inner_feature_index = train_data_->InnerFeatureIndex(best_split_info.feature);
+ const int inner_feature_index =
+ train_data_->InnerFeatureIndex(best_split_info.feature);
if (cegb_ != nullptr) {
- cegb_->UpdateLeafBestSplits(tree, best_leaf, &best_split_info, &best_split_per_leaf_);
+ cegb_->UpdateLeafBestSplits(tree, best_leaf, &best_split_info,
+ &best_split_per_leaf_);
}
*left_leaf = best_leaf;
auto next_leaf_id = tree->NextLeafId();

- bool is_numerical_split = train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin;
+ bool is_numerical_split =
+ train_data_->FeatureBinMapper(inner_feature_index)->bin_type() ==
+ BinType::NumericalBin;
if (is_numerical_split) {
- auto threshold_double = train_data_->RealThreshold(inner_feature_index, best_split_info.threshold);
+ auto threshold_double = train_data_->RealThreshold(
+ inner_feature_index, best_split_info.threshold);
data_partition_->Split(best_leaf, train_data_, inner_feature_index,
- &best_split_info.threshold, 1, best_split_info.default_left, next_leaf_id);
- best_split_info.left_count = data_partition_->leaf_count(*left_leaf);
- best_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
+ &best_split_info.threshold, 1,
+ best_split_info.default_left, next_leaf_id);
+ if (update_cnt) {
+ // don't need to update this in data-based parallel model
+ best_split_info.left_count = data_partition_->leaf_count(*left_leaf);
+ best_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
+ }
// split tree, will return right leaf
- *right_leaf = tree->Split(best_leaf,
- inner_feature_index,
- best_split_info.feature,
- best_split_info.threshold,
- threshold_double,
- static_cast<double>(best_split_info.left_output),
- static_cast<double>(best_split_info.right_output),
- static_cast<data_size_t>(best_split_info.left_count),
- static_cast<data_size_t>(best_split_info.right_count),
- static_cast<double>(best_split_info.left_sum_hessian),
- static_cast<double>(best_split_info.right_sum_hessian),
- static_cast<float>(best_split_info.gain),
- train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
- best_split_info.default_left);
+ *right_leaf = tree->Split(
+ best_leaf, inner_feature_index, best_split_info.feature,
+ best_split_info.threshold, threshold_double,
+ static_cast<double>(best_split_info.left_output),
+ static_cast<double>(best_split_info.right_output),
+ static_cast<data_size_t>(best_split_info.left_count),
+ static_cast<data_size_t>(best_split_info.right_count),
+ static_cast<double>(best_split_info.left_sum_hessian),
+ static_cast<double>(best_split_info.right_sum_hessian),
+ static_cast<float>(best_split_info.gain),
+ train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
+ best_split_info.default_left);
} else {
- std::vector<uint32_t> cat_bitset_inner = Common::ConstructBitset(best_split_info.cat_threshold.data(), best_split_info.num_cat_threshold);
+ std::vector<uint32_t> cat_bitset_inner =
+ Common::ConstructBitset(best_split_info.cat_threshold.data(),
+ best_split_info.num_cat_threshold);
std::vector<int> threshold_int(best_split_info.num_cat_threshold);
for (int i = 0; i < best_split_info.num_cat_threshold; ++i) {
- threshold_int[i] = static_cast<int>(train_data_->RealThreshold(inner_feature_index, best_split_info.cat_threshold[i]));
+ threshold_int[i] = static_cast<int>(train_data_->RealThreshold(
+ inner_feature_index, best_split_info.cat_threshold[i]));
}
- std::vector<uint32_t> cat_bitset = Common::ConstructBitset(threshold_int.data(), best_split_info.num_cat_threshold);
+ std::vector<uint32_t> cat_bitset = Common::ConstructBitset(
+ threshold_int.data(), best_split_info.num_cat_threshold);

data_partition_->Split(best_leaf, train_data_, inner_feature_index,
- cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()), best_split_info.default_left, next_leaf_id);
-
- best_split_info.left_count = data_partition_->leaf_count(*left_leaf);
- best_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
-
- *right_leaf = tree->SplitCategorical(best_leaf,
- inner_feature_index,
- best_split_info.feature,
- cat_bitset_inner.data(),
- static_cast<int>(cat_bitset_inner.size()),
- cat_bitset.data(),
- static_cast<int>(cat_bitset.size()),
- static_cast<double>(best_split_info.left_output),
- static_cast<double>(best_split_info.right_output),
- static_cast<data_size_t>(best_split_info.left_count),
- static_cast<data_size_t>(best_split_info.right_count),
- static_cast<double>(best_split_info.left_sum_hessian),
- static_cast<double>(best_split_info.right_sum_hessian),
- static_cast<float>(best_split_info.gain),
- train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
- }
-
- #ifdef DEBUG
+ cat_bitset_inner.data(),
+ static_cast<int>(cat_bitset_inner.size()),
+ best_split_info.default_left, next_leaf_id);
+
+ if (update_cnt) {
+ // don't need to update this in data-based parallel model
+ best_split_info.left_count = data_partition_->leaf_count(*left_leaf);
+ best_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
+ }
+
+ *right_leaf = tree->SplitCategorical(
+ best_leaf, inner_feature_index, best_split_info.feature,
+ cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()),
+ cat_bitset.data(), static_cast<int>(cat_bitset.size()),
+ static_cast<double>(best_split_info.left_output),
+ static_cast<double>(best_split_info.right_output),
+ static_cast<data_size_t>(best_split_info.left_count),
+ static_cast<data_size_t>(best_split_info.right_count),
+ static_cast<double>(best_split_info.left_sum_hessian),
+ static_cast<double>(best_split_info.right_sum_hessian),
+ static_cast<float>(best_split_info.gain),
+ train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
+ }
+
+ #ifdef DEBUG
CHECK(*right_leaf == next_leaf_id);
- #endif
+ #endif

// init the leaves that used on next iteration
if (best_split_info.left_count < best_split_info.right_count) {
CHECK_GT(best_split_info.left_count, 0);
- smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian);
- larger_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian);
+ smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
+ best_split_info.left_sum_gradient,
+ best_split_info.left_sum_hessian);
+ larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
+ best_split_info.right_sum_gradient,
+ best_split_info.right_sum_hessian);
} else {
CHECK_GT(best_split_info.right_count, 0);
- smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian);
- larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian);
- }
- constraints_->UpdateConstraints(
- is_numerical_split, *left_leaf, *right_leaf,
- best_split_info.monotone_type, best_split_info.right_output,
- best_split_info.left_output);
- }
+ smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
+ best_split_info.right_sum_gradient,
+ best_split_info.right_sum_hessian);
+ larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
+ best_split_info.left_sum_gradient,
+ best_split_info.left_sum_hessian);
+ }
+ constraints_->UpdateConstraints(is_numerical_split, *left_leaf, *right_leaf,
+ best_split_info.monotone_type,
+ best_split_info.right_output,
+ best_split_info.left_output);
+
+ }

void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const {
7 changes: 6 additions & 1 deletion src/treelearner/serial_tree_learner.h
@@ -137,7 +137,12 @@ class SerialTreeLearner: public TreeLearner {
* \param left_leaf The index of left leaf after splitted.
* \param right_leaf The index of right leaf after splitted.
*/
- virtual void Split(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf);
+ inline virtual void Split(Tree* tree, int best_leaf, int* left_leaf,
+ int* right_leaf) {
+ SplitInner(tree, best_leaf, left_leaf, right_leaf, true);
+ }
+
+ void SplitInner(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf, bool update_cnt);

/* Force splits with forced_split_json dict and then return num splits forced.*/
virtual int32_t ForceSplits(Tree* tree, const Json& forced_split_json, int* left_leaf,
2 changes: 1 addition & 1 deletion src/treelearner/voting_parallel_tree_learner.cpp
@@ -429,7 +429,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons

template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
- TREELEARNER_T::Split(tree, best_Leaf, left_leaf, right_leaf);
+ this->SplitInner(tree, best_Leaf, left_leaf, right_leaf, false);
const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf];
// set the global number of data for leaves
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
