Skip to content

Commit

Permalink
Revert OMP guard. (#6987)
Browse files Browse the repository at this point in the history
The guard protects the global variable from being changed by XGBoost.  But this leads to a
bug that the `n_threads` parameter is no longer used after the first iteration.  This is
due to the fact that `omp_set_num_threads` is only called once in `Learner::Configure` at
the beginning of the training process.

The guard is still useful for `gpu_id`, since this is called all the times in our codebase
doesn't matter which iteration we are currently running.
  • Loading branch information
trivialfis authored May 25, 2021
1 parent cf06a26 commit 6e52aef
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 35 deletions.
2 changes: 0 additions & 2 deletions src/c_api/c_api_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ inline float GetMissing(Json const &config) {

// Safe guard some global variables from being changed by XGBoost.
class XGBoostAPIGuard {
int32_t n_threads_ {omp_get_max_threads()};
int32_t device_id_ {0};

#if defined(XGBOOST_USE_CUDA)
Expand All @@ -179,7 +178,6 @@ class XGBoostAPIGuard {
SetGPUAttribute();
}
~XGBoostAPIGuard() {
omp_set_num_threads(n_threads_);
RestoreGPUAttribute();
}
};
Expand Down
33 changes: 0 additions & 33 deletions tests/cpp/c_api/test_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,37 +278,4 @@ TEST(CAPI, XGBGlobalConfig) {
ASSERT_EQ(err.find("verbosity"), std::string::npos);
}
}

TEST(CAPI, GlobalVariables) {
size_t n_threads = omp_get_max_threads();
size_t constexpr kRows = 10;
bst_feature_t constexpr kCols = 2;

DMatrixHandle handle;
std::vector<float> data(kCols * kRows, 1.5);


ASSERT_EQ(XGDMatrixCreateFromMat_omp(data.data(), kRows, kCols,
std::numeric_limits<float>::quiet_NaN(),
&handle, 0),
0);
std::vector<float> labels(kRows, 2.0f);
ASSERT_EQ(XGDMatrixSetFloatInfo(handle, "label", labels.data(), labels.size()), 0);

DMatrixHandle m_handles[1];
m_handles[0] = handle;

BoosterHandle booster;
ASSERT_EQ(XGBoosterCreate(m_handles, 1, &booster), 0);
ASSERT_EQ(XGBoosterSetParam(booster, "nthread", "16"), 0);

omp_set_num_threads(1);
ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, handle), 0);
ASSERT_EQ(omp_get_max_threads(), 1);

omp_set_num_threads(n_threads);

XGDMatrixFree(handle);
XGBoosterFree(booster);
}
} // namespace xgboost

0 comments on commit 6e52aef

Please sign in to comment.