diff --git a/R-package/tests/testthat/test_interaction_constraints.R b/R-package/tests/testthat/test_interaction_constraints.R index 1b4902576e61..9a3ddf442f1b 100644 --- a/R-package/tests/testthat/test_interaction_constraints.R +++ b/R-package/tests/testthat/test_interaction_constraints.R @@ -14,25 +14,42 @@ test_that("interaction constraints for regression", { bst <- xgboost(data = train, label = y, max_depth = 3, eta = 0.1, nthread = 2, nrounds = 100, verbose = 0, interaction_constraints = list(c(0,1))) - + # Set all observations to have the same x3 values then increment # by the same amount - preds <- lapply(c(1,2,3), function(x){ - tmat <- matrix(c(x1,x2,rep(x,1000)), ncol=3) - return(predict(bst, tmat)) - }) + preds <- lapply(c(1,2,3), function(x){ + tmat <- matrix(c(x1,x2,rep(x,1000)), ncol=3) + return(predict(bst, tmat)) + }) # Check incrementing x3 has the same effect on all observations # since x3 is constrained to be independent of x1 and x2 # and all observations start off from the same x3 value - diff1 <- preds[[2]] - preds[[1]] - test1 <- all(abs(diff1 - diff1[1]) < 1e-4) - - diff2 <- preds[[3]] - preds[[2]] - test2 <- all(abs(diff2 - diff2[1]) < 1e-4) - + diff1 <- preds[[2]] - preds[[1]] + test1 <- all(abs(diff1 - diff1[1]) < 1e-4) + + diff2 <- preds[[3]] - preds[[2]] + test2 <- all(abs(diff2 - diff2[1]) < 1e-4) + expect_true({ test1 & test2 }, "Interaction Contraint Satisfied") - +}) + +test_that("interaction constraints scientific representation", { + rows <- 10 + ## When number exceeds 1e5, R paste function uses scientific representation. 
+ ## See: https://github.com/dmlc/xgboost/issues/5179 + cols <- 1e5+10 + + d <- matrix(rexp(rows, rate=.1), nrow=rows, ncol=cols) + y <- rnorm(rows) + + dtrain <- xgb.DMatrix(data=d, info = list(label=y)) + inc <- list(c(seq.int(from = 0, to = cols, by = 1))) + + with_inc <- xgb.train(data=dtrain, tree_method='hist', + interaction_constraints=inc, nrounds=10) + without_inc <- xgb.train(data=dtrain, tree_method='hist', nrounds=10) + expect_equal(xgb.save.raw(with_inc), xgb.save.raw(without_inc)) }) diff --git a/src/tree/constraints.cc b/src/tree/constraints.cc index 5e5d440f7147..a1aa9dc0570e 100644 --- a/src/tree/constraints.cc +++ b/src/tree/constraints.cc @@ -6,6 +6,7 @@ #include #include "xgboost/span.h" +#include "xgboost/json.h" #include "constraints.h" #include "param.h" @@ -27,15 +28,12 @@ void FeatureInteractionConstraintHost::Reset() { if (!enabled_) { return; } - // Parse interaction constraints - std::istringstream iss(this->interaction_constraint_str_); - dmlc::JSONReader reader(&iss); - // Read std::vector> first and then - // convert to std::vector> - std::vector> tmp; + // Read std::vector> first and then + // convert to std::vector> + std::vector> tmp; try { - reader.Read(&tmp); - } catch (dmlc::Error const& e) { + ParseInteractionConstraint(this->interaction_constraint_str_, &tmp); + } catch (dmlc::Error const &e) { LOG(FATAL) << "Failed to parse feature interaction constraint:\n" << this->interaction_constraint_str_ << "\n" << "With error:\n" << e.what(); diff --git a/src/tree/constraints.cu b/src/tree/constraints.cu index 371ee790c14c..b6db0eda0739 100644 --- a/src/tree/constraints.cu +++ b/src/tree/constraints.cu @@ -7,9 +7,7 @@ #include #include -#include #include -#include #include #include "xgboost/logging.h" @@ -18,14 +16,13 @@ #include "param.h" #include "../common/device_helpers.cuh" - namespace xgboost { -size_t FeatureInteractionConstraint::Features() const { +size_t FeatureInteractionConstraintDevice::Features() const { return 
d_sets_ptr_.size() - 1; } -void FeatureInteractionConstraint::Configure( +void FeatureInteractionConstraintDevice::Configure( tree::TrainParam const& param, int32_t const n_features) { has_constraint_ = true; if (param.interaction_constraints.length() == 0) { @@ -33,13 +30,11 @@ void FeatureInteractionConstraint::Configure( return; } // --- Parse interaction constraints - std::istringstream iss(param.interaction_constraints); - dmlc::JSONReader reader(&iss); // Interaction constraints parsed from string parameter. After // parsing, this looks like {{0, 1, 2}, {2, 3 ,4}}. - std::vector> h_feature_constraints; + std::vector> h_feature_constraints; try { - reader.Read(&h_feature_constraints); + ParseInteractionConstraint(param.interaction_constraints, &h_feature_constraints); } catch (dmlc::Error const& e) { LOG(FATAL) << "Failed to parse feature interaction constraint:\n" << param.interaction_constraints << "\n" @@ -68,13 +63,13 @@ void FeatureInteractionConstraint::Configure( // Represent constraints as CSR format, flatten is the value vector, // ptr is row_ptr vector in CSR. 
- std::vector h_feature_constraints_flatten; + std::vector h_feature_constraints_flatten; for (auto const& constraints : h_feature_constraints) { - for (int32_t c : constraints) { + for (uint32_t c : constraints) { h_feature_constraints_flatten.emplace_back(c); } } - std::vector h_feature_constraints_ptr; + std::vector h_feature_constraints_ptr; size_t n_features_in_constraints = 0; h_feature_constraints_ptr.emplace_back(n_features_in_constraints); for (auto const& v : h_feature_constraints) { @@ -130,13 +125,13 @@ void FeatureInteractionConstraint::Configure( s_result_buffer_ = dh::ToSpan(result_buffer_); } -FeatureInteractionConstraint::FeatureInteractionConstraint( +FeatureInteractionConstraintDevice::FeatureInteractionConstraintDevice( tree::TrainParam const& param, int32_t const n_features) : has_constraint_{true}, n_sets_{0} { this->Configure(param, n_features); } -void FeatureInteractionConstraint::Reset() { +void FeatureInteractionConstraintDevice::Reset() { for (auto& node : node_constraints_storage_) { thrust::fill(node.begin(), node.end(), 0); } @@ -153,7 +148,7 @@ __global__ void ClearBuffersKernel( } } -void FeatureInteractionConstraint::ClearBuffers() { +void FeatureInteractionConstraintDevice::ClearBuffers() { CHECK_EQ(output_buffer_bits_.Size(), input_buffer_bits_.Size()); CHECK_LE(feature_buffer_.Size(), output_buffer_bits_.Size()); uint32_t constexpr kBlockThreads = 256; @@ -164,7 +159,7 @@ void FeatureInteractionConstraint::ClearBuffers() { output_buffer_bits_, input_buffer_bits_); } -common::Span FeatureInteractionConstraint::QueryNode(int32_t node_id) { +common::Span FeatureInteractionConstraintDevice::QueryNode(int32_t node_id) { if (!has_constraint_) { return {}; } CHECK_LT(node_id, s_node_constraints_.size()); @@ -203,7 +198,7 @@ __global__ void QueryFeatureListKernel(LBitField64 node_constraints, result_buffer_output &= result_buffer_input; } -common::Span FeatureInteractionConstraint::Query( +common::Span 
FeatureInteractionConstraintDevice::Query( common::Span feature_list, int32_t nid) { if (!has_constraint_ || nid == 0) { return feature_list; @@ -250,8 +245,8 @@ __global__ void RestoreFeatureListFromSetsKernel( LBitField64 feature_buffer, bst_feature_t fid, - common::Span feature_interactions, - common::Span feature_interactions_ptr, // of size n interaction set + 1 + common::Span feature_interactions, + common::Span feature_interactions_ptr, // of size n interaction set + 1 common::Span interactions_list, common::Span interactions_list_ptr) { @@ -302,7 +297,7 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature, } } -void FeatureInteractionConstraint::Split( +void FeatureInteractionConstraintDevice::Split( bst_node_t node_id, bst_feature_t feature_id, bst_node_t left_id, bst_node_t right_id) { if (!has_constraint_) { return; } CHECK_NE(node_id, left_id) diff --git a/src/tree/constraints.cuh b/src/tree/constraints.cuh index 3e982a00f6dd..cd5fbf26b066 100644 --- a/src/tree/constraints.cuh +++ b/src/tree/constraints.cuh @@ -88,18 +88,18 @@ struct ValueConstraint { }; // Feature interaction constraints built for GPU Hist updater. -struct FeatureInteractionConstraint { +struct FeatureInteractionConstraintDevice { protected: // Whether interaction constraint is used. bool has_constraint_; // n interaction sets. - int32_t n_sets_; + size_t n_sets_; // The parsed feature interaction constraints as CSR. - dh::device_vector d_fconstraints_; - common::Span s_fconstraints_; - dh::device_vector d_fconstraints_ptr_; - common::Span s_fconstraints_ptr_; + dh::device_vector d_fconstraints_; + common::Span s_fconstraints_; + dh::device_vector d_fconstraints_ptr_; + common::Span s_fconstraints_ptr_; /* Interaction sets for each feature as CSR. 
For an input like: * [[0, 1], [1, 2]], this will have values: * @@ -141,11 +141,11 @@ struct FeatureInteractionConstraint { public: size_t Features() const; - FeatureInteractionConstraint() = default; + FeatureInteractionConstraintDevice() = default; void Configure(tree::TrainParam const& param, int32_t const n_features); - FeatureInteractionConstraint(tree::TrainParam const& param, int32_t const n_features); - FeatureInteractionConstraint(FeatureInteractionConstraint const& that) = default; - FeatureInteractionConstraint(FeatureInteractionConstraint&& that) = default; + FeatureInteractionConstraintDevice(tree::TrainParam const& param, int32_t const n_features); + FeatureInteractionConstraintDevice(FeatureInteractionConstraintDevice const& that) = default; + FeatureInteractionConstraintDevice(FeatureInteractionConstraintDevice&& that) = default; /*! \brief Reset before constructing a new tree. */ void Reset(); /*! \brief Return a list of features given node id */ diff --git a/src/tree/param.cc b/src/tree/param.cc index 8049501ea094..6f5080ee24ad 100644 --- a/src/tree/param.cc +++ b/src/tree/param.cc @@ -5,6 +5,7 @@ #include #include +#include "xgboost/json.h" #include "param.h" namespace std { @@ -79,3 +80,31 @@ std::istream &operator>>(std::istream &is, std::vector &t) { return is; } } // namespace std + +namespace xgboost { +void ParseInteractionConstraint( + std::string const &constraint_str, + std::vector> *p_out) { + auto &out = *p_out; + auto j_inc = Json::Load({constraint_str.c_str(), constraint_str.size()}); + auto const &all = get(j_inc); + out.resize(all.size()); + for (size_t i = 0; i < all.size(); ++i) { + auto const &set = get(all[i]); + for (auto const &v : set) { + if (XGBOOST_EXPECT(IsA(v), true)) { + uint32_t u = static_cast(get(v)); + out[i].emplace_back(u); + } else if (IsA(v)) { + double d = get(v); + CHECK_EQ(std::floor(d), d) + << "Found floating point number in interaction constraints"; + out[i].emplace_back(static_cast(d)); + } else { + 
LOG(FATAL) << "Unknown value type for interaction constraint:" + << v.GetValue().TypeStr(); + } + } + } +} +} // namespace xgboost diff --git a/src/tree/param.h b/src/tree/param.h index f8c961cf7de0..8a71cd1ef5a8 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -483,8 +483,21 @@ struct SplitEntryContainer { }; using SplitEntry = SplitEntryContainer; - } // namespace tree + +/* + * \brief Parse the interaction constraints from string. + * \param constraint_str String storing the interaction constraints: + * + * Example input string: + * + * "[[1, 2], [3, 4]]" + * + * \param p_out Pointer to output + */ +void ParseInteractionConstraint( + std::string const &constraint_str, + std::vector> *p_out); } // namespace xgboost // define string serializer for vector, to get the arguments diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 5ea3d4456d11..e9482b203342 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -436,7 +436,7 @@ struct GPUHistMakerDevice { common::Monitor monitor; std::vector node_value_constraints; common::ColumnSampler column_sampler; - FeatureInteractionConstraint interaction_constraints; + FeatureInteractionConstraintDevice interaction_constraints; using ExpandQueue = std::priority_queue, diff --git a/tests/cpp/tree/test_constraints.cu b/tests/cpp/tree/test_constraints.cu index f1d73df8f216..38e34ae481ae 100644 --- a/tests/cpp/tree/test_constraints.cu +++ b/tests/cpp/tree/test_constraints.cu @@ -15,12 +15,12 @@ namespace xgboost { namespace { -struct FConstraintWrapper : public FeatureInteractionConstraint { +struct FConstraintWrapper : public FeatureInteractionConstraintDevice { common::Span GetNodeConstraints() { - return FeatureInteractionConstraint::s_node_constraints_; + return FeatureInteractionConstraintDevice::s_node_constraints_; } FConstraintWrapper(tree::TrainParam param, bst_feature_t n_features) : - FeatureInteractionConstraint(param, n_features) {} +
FeatureInteractionConstraintDevice(param, n_features) {} dh::device_vector const& GetDSets() const { return d_sets_;