Skip to content

Commit

Permalink
Start investigating group split.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 28, 2020
1 parent d3d03e4 commit 81e2a0e
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
#include "../src/logging.cc"
#include "../src/common/common.cc"
#include "../src/common/charconv.cc"
#include "../src/common/categorical.cc"
#include "../src/common/timer.cc"
#include "../src/common/host_device_vector.cc"
#include "../src/common/hist_util.cc"
Expand Down
12 changes: 12 additions & 0 deletions src/common/categorical.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*!
* Copyright 2020 by XGBoost Contributors
* \file categorical.cc
*/

#include "categorical.h"

namespace xgboost {
namespace common {
DMLC_REGISTER_PARAMETER(CategoricalParam);
} // namespace common
} // namespace xgboost
12 changes: 12 additions & 0 deletions src/common/categorical.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/span.h"
#include "xgboost/parameter.h"

namespace xgboost {
namespace common {
Expand Down Expand Up @@ -43,6 +44,17 @@ inline XGBOOST_DEVICE bool GoLeft(float l, float r, bool is_cat, Comp comp) {
}
return comp(l, r);
}

struct CategoricalParam : XGBoostParameter<CategoricalParam> {
int32_t max_cat_to_onehot {-1};

DMLC_DECLARE_PARAMETER(CategoricalParam) {
DMLC_DECLARE_FIELD(max_cat_to_onehot)
.set_lower_bound(-1)
.set_default(-1)
.describe("Maximum number of categories for one hot split.");
}
};
} // namespace common
} // namespace xgboost

Expand Down
4 changes: 4 additions & 0 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ class GPUHistMaker : public TreeUpdater {
// The passed in args can be empty, if we simply purge the old maker without
// preserving parameters then we can't do Update on it.
TrainParam param;
cat_param_.UpdateAllowUnknown(args);
if (float_maker_) {
param = float_maker_->param_;
} else if (double_maker_) {
Expand All @@ -884,6 +885,7 @@ class GPUHistMaker : public TreeUpdater {
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
FromJson(config.at("categorical_param"), &this->cat_param_);
if (hist_maker_param_.single_precision_histogram) {
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
FromJson(config.at("train_param"), &float_maker_->param_);
Expand All @@ -895,6 +897,7 @@ class GPUHistMaker : public TreeUpdater {
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
out["categorical_param"] = ToJson(cat_param_);
if (hist_maker_param_.single_precision_histogram) {
out["train_param"] = ToJson(float_maker_->param_);
} else {
Expand Down Expand Up @@ -926,6 +929,7 @@ class GPUHistMaker : public TreeUpdater {

private:
GPUHistMakerTrainParam hist_maker_param_;
common::CategoricalParam cat_param_;
std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
};
Expand Down

0 comments on commit 81e2a0e

Please sign in to comment.