From 14bd864fc0f02ae1719ab49dd6226a91d37bb576 Mon Sep 17 00:00:00 2001 From: hzy46 Date: Fri, 22 Oct 2021 15:02:05 +0800 Subject: [PATCH 1/8] fix fix fix fix fix fix fix fix fix fix --- include/LightGBM/utils/openmp_wrapper.h | 6 ++++ src/application/application.cpp | 2 ++ src/c_api.cpp | 38 +++++++++++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h index b48113e10ba4..2101837e3ba8 100644 --- a/include/LightGBM/utils/openmp_wrapper.h +++ b/include/LightGBM/utils/openmp_wrapper.h @@ -25,6 +25,12 @@ inline int OMP_NUM_THREADS() { return ret; } +const static int default_omp_num_threads = OMP_NUM_THREADS(); + +inline void omp_reset_num_threads() { + omp_set_num_threads(default_omp_num_threads); +} + class ThreadExceptionHelper { public: ThreadExceptionHelper() { diff --git a/src/application/application.cpp b/src/application/application.cpp index d9a4d7544ebc..b20c571efee9 100644 --- a/src/application/application.cpp +++ b/src/application/application.cpp @@ -33,6 +33,8 @@ Application::Application(int argc, char** argv) { // set number of threads for openmp if (config_.num_threads > 0) { omp_set_num_threads(config_.num_threads); + } else { + omp_reset_num_threads(); } if (config_.data.size() == 0 && config_.task != TaskType::kConvertModel) { Log::Fatal("No training/prediction data, application quit"); diff --git a/src/c_api.cpp b/src/c_api.cpp index bc3bfc3b2434..19eddb628868 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -115,6 +115,8 @@ class Booster { config_.Set(param); if (config_.num_threads > 0) { omp_set_num_threads(config_.num_threads); + } else { + omp_reset_num_threads(); } // create boosting if (config_.input_model.size() > 0) { @@ -316,6 +318,8 @@ class Booster { if (config_.num_threads > 0) { omp_set_num_threads(config_.num_threads); + } else { + omp_reset_num_threads(); } if (param.count("objective")) { @@ -953,6 +957,8 @@ int LGBM_DatasetCreateFromFile(const char* filename, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } DatasetLoader loader(config, nullptr, 1, filename); if (reference == nullptr) { @@ -983,6 +989,8 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } DatasetLoader loader(config, nullptr, 1, nullptr); *out = loader.ConstructFromSampleData(sample_data, sample_indices, ncol, num_per_col, @@ -1098,6 +1106,8 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } std::unique_ptr ret; int32_t total_nrow = 0; @@ -1190,6 +1200,8 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } std::unique_ptr ret; auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); @@ -1258,6 +1270,8 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } std::unique_ptr ret; int32_t nrow = num_rows; @@ -1330,6 +1344,8 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } std::unique_ptr ret; int32_t nrow = static_cast(num_row); @@ -1411,6 +1427,8 @@ int LGBM_DatasetGetSubset( config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } auto full_dataset = reinterpret_cast(handle); CHECK_GT(num_used_row_indices, 0); @@ -1818,6 +1836,8 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } Booster* ref_booster = reinterpret_cast(handle); ref_booster->Predict(start_iteration, num_iteration, predict_type, data_filename, data_has_header, @@ -1896,6 +1916,8 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); @@ -1930,6 +1952,8 @@ int LGBM_BoosterPredictSparseOutput(BoosterHandle handle, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } if (matrix_type == C_API_MATRIX_TYPE_CSR) { if (num_col_or_row <= 0) { @@ -2017,6 +2041,8 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); @@ -2049,6 +2075,8 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle, if (fastConfig_ptr->config.num_threads > 0) { omp_set_num_threads(fastConfig_ptr->config.num_threads); + } else { + omp_reset_num_threads(); } fastConfig_ptr->booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, fastConfig_ptr->config); @@ -2097,6 +2125,8 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } int num_threads = OMP_NUM_THREADS(); int ncol = static_cast(ncol_ptr - 1); @@ -2142,6 +2172,8 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); @@ -2167,6 +2199,8 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major); @@ -2193,6 +2227,8 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle, if (fastConfig_ptr->config.num_threads > 0) { omp_set_num_threads(fastConfig_ptr->config.num_threads); + } else { + omp_reset_num_threads(); } fastConfig_ptr->booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, fastConfig_ptr->config); @@ -2233,6 +2269,8 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle, config.Set(param); if (config.num_threads > 0) { omp_set_num_threads(config.num_threads); + } else { + omp_reset_num_threads(); } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type); From 82b3fe267c5ee97fb7a04d820c7ba15ff9f82593 Mon Sep 17 00:00:00 2001 From: hzy46 Date: Fri, 22 Oct 2021 16:10:45 +0800 Subject: [PATCH 2/8] fix --- include/LightGBM/utils/openmp_wrapper.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h index 2101837e3ba8..bb3464760241 100644 --- a/include/LightGBM/utils/openmp_wrapper.h +++ b/include/LightGBM/utils/openmp_wrapper.h @@ -25,7 +25,7 @@ inline int OMP_NUM_THREADS() { return ret; } -const static int default_omp_num_threads = OMP_NUM_THREADS(); +static const int default_omp_num_threads = OMP_NUM_THREADS(); inline void omp_reset_num_threads() { omp_set_num_threads(default_omp_num_threads); From 1222a7c09cebbe1742a42845b7e1f67f978f7522 Mon Sep 17 00:00:00 2001 From: hzy46 Date: Fri, 22 Oct 2021 17:11:51 +0800 Subject: [PATCH 3/8] mock func for no openmp --- include/LightGBM/utils/openmp_wrapper.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h index bb3464760241..82811b36e7e4 100644 --- a/include/LightGBM/utils/openmp_wrapper.h +++ b/include/LightGBM/utils/openmp_wrapper.h @@ -100,6 +100,7 @@ class ThreadExceptionHelper { simulate a single thread running. All #pragma omp should be ignored by the compiler **/ inline void omp_set_num_threads(int) __GOMP_NOTHROW {} // NOLINT (no cast done here) + inline void omp_reset_num_threads() __GOMP_NOTHROW {} inline int omp_get_num_threads() __GOMP_NOTHROW {return 1;} inline int omp_get_max_threads() __GOMP_NOTHROW {return 1;} inline int omp_get_thread_num() __GOMP_NOTHROW {return 0;} From 1598f97ce5642776b3c381166c4eb57bb5f03c0a Mon Sep 17 00:00:00 2001 From: hzy46 Date: Tue, 26 Oct 2021 16:36:25 +0800 Subject: [PATCH 4/8] fix --- include/LightGBM/utils/openmp_wrapper.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h index 82811b36e7e4..04a2878ad1a0 100644 --- a/include/LightGBM/utils/openmp_wrapper.h +++ b/include/LightGBM/utils/openmp_wrapper.h @@ -25,7 +25,8 @@ inline int OMP_NUM_THREADS() { return ret; } -static const int default_omp_num_threads = OMP_NUM_THREADS(); +static const int default_omp_num_threads = omp_get_num_threads(); + inline void omp_reset_num_threads() { omp_set_num_threads(default_omp_num_threads); From f8cd199b899db66ca111b3723514b537776660ff Mon Sep 17 00:00:00 2001 From: hzy46 Date: Tue, 26 Oct 2021 16:41:32 +0800 Subject: [PATCH 5/8] minor --- include/LightGBM/utils/openmp_wrapper.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h index 04a2878ad1a0..4a0043b5b92e 100644 --- a/include/LightGBM/utils/openmp_wrapper.h +++ b/include/LightGBM/utils/openmp_wrapper.h @@ -27,7 +27,6 @@ inline int OMP_NUM_THREADS() { static const int default_omp_num_threads = omp_get_num_threads(); - inline void omp_reset_num_threads() { omp_set_num_threads(default_omp_num_threads); } From 2904014121970d78b9cab1d6a7d3445c559f26dd Mon Sep 17 00:00:00 2001 From: hzy46 Date: Tue, 26 Oct 2021 17:36:16 +0800 Subject: [PATCH 6/8] use omp_get_max_threads --- include/LightGBM/utils/openmp_wrapper.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h index 4a0043b5b92e..d01dd60784fa 100644 --- a/include/LightGBM/utils/openmp_wrapper.h +++ b/include/LightGBM/utils/openmp_wrapper.h @@ -25,7 +25,7 @@ inline int OMP_NUM_THREADS() { return ret; } -static const int default_omp_num_threads = omp_get_num_threads(); +static const int default_omp_num_threads = omp_get_max_threads(); inline void omp_reset_num_threads() { omp_set_num_threads(default_omp_num_threads); From 44f0c36469e45d1d7fee74f8d80a52ece940f987 Mon Sep 17 00:00:00 2001 From: hzy46 Date: Thu, 28 Oct 2021 14:42:59 +0800 Subject: [PATCH 7/8] fix --- include/LightGBM/utils/openmp_wrapper.h | 14 +-- src/application/application.cpp | 6 +- src/c_api.cpp | 114 ++++-------------------- 3 files changed, 29 insertions(+), 105 deletions(-) diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h index d01dd60784fa..b826e79e055a 100644 --- a/include/LightGBM/utils/openmp_wrapper.h +++ b/include/LightGBM/utils/openmp_wrapper.h @@ -25,10 +25,14 @@ inline int OMP_NUM_THREADS() { return ret; } -static const int default_omp_num_threads = omp_get_max_threads(); - -inline void omp_reset_num_threads() { - omp_set_num_threads(default_omp_num_threads); +inline void OMP_SET_NUM_THREADS(int num_threads) { + static const int default_omp_num_threads = OMP_NUM_THREADS(); + if (num_threads > 0) { + omp_set_num_threads(num_threads); + } + else { + omp_set_num_threads(default_omp_num_threads); + } } class ThreadExceptionHelper { @@ -100,7 +104,7 @@ class ThreadExceptionHelper { simulate a single thread running. All #pragma omp should be ignored by the compiler **/ inline void omp_set_num_threads(int) __GOMP_NOTHROW {} // NOLINT (no cast done here) - inline void omp_reset_num_threads() __GOMP_NOTHROW {} + inline void OMP_SET_NUM_THREADS(int) __GOMP_NOTHROW {} inline int omp_get_num_threads() __GOMP_NOTHROW {return 1;} inline int omp_get_max_threads() __GOMP_NOTHROW {return 1;} inline int omp_get_thread_num() __GOMP_NOTHROW {return 0;} diff --git a/src/application/application.cpp b/src/application/application.cpp index b20c571efee9..b7f55a2ec0e4 100644 --- a/src/application/application.cpp +++ b/src/application/application.cpp @@ -31,11 +31,7 @@ namespace LightGBM { Application::Application(int argc, char** argv) { LoadParameters(argc, argv); // set number of threads for openmp - if (config_.num_threads > 0) { - omp_set_num_threads(config_.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config_.num_threads); if (config_.data.size() == 0 && config_.task != TaskType::kConvertModel) { Log::Fatal("No training/prediction data, application quit"); } diff --git a/src/c_api.cpp b/src/c_api.cpp index 19eddb628868..ec3880b60675 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -113,11 +113,7 @@ class Booster { const char* parameters) { auto param = Config::Str2Map(parameters); config_.Set(param); - if (config_.num_threads > 0) { - omp_set_num_threads(config_.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config_.num_threads); // create boosting if (config_.input_model.size() > 0) { Log::Warning("Continued train from model is not supported for c_api,\n" @@ -316,11 +312,7 @@ class Booster { config_.Set(param); - if (config_.num_threads > 0) { - omp_set_num_threads(config_.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config_.num_threads); if (param.count("objective")) { // create objective function @@ -955,11 +947,7 @@ int LGBM_DatasetCreateFromFile(const char* filename, auto param = Config::Str2Map(parameters); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); DatasetLoader loader(config, nullptr, 1, filename); if (reference == nullptr) { if (Network::num_machines() == 1) { @@ -987,11 +975,7 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data, auto param = Config::Str2Map(parameters); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); DatasetLoader loader(config, nullptr, 1, nullptr); *out = loader.ConstructFromSampleData(sample_data, sample_indices, ncol, num_per_col, num_sample_row, @@ -1104,11 +1088,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, auto param = Config::Str2Map(parameters); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); std::unique_ptr ret; int32_t total_nrow = 0; for (int j = 0; j < nmat; ++j) { @@ -1198,11 +1178,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, auto param = Config::Str2Map(parameters); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); std::unique_ptr ret; auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int32_t nrow = static_cast(nindptr - 1); @@ -1268,11 +1244,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, auto param = Config::Str2Map(parameters); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); std::unique_ptr ret; int32_t nrow = num_rows; if (reference == nullptr) { @@ -1342,11 +1314,7 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, auto param = Config::Str2Map(parameters); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); std::unique_ptr ret; int32_t nrow = static_cast(num_row); if (reference == nullptr) { @@ -1425,11 +1393,7 @@ int LGBM_DatasetGetSubset( auto param = Config::Str2Map(parameters); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); auto full_dataset = reinterpret_cast(handle); CHECK_GT(num_used_row_indices, 0); const int32_t lower = 0; @@ -1834,11 +1798,7 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle, auto param = Config::Str2Map(parameter); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); Booster* ref_booster = reinterpret_cast(handle); ref_booster->Predict(start_iteration, num_iteration, predict_type, data_filename, data_has_header, config, result_filename); @@ -1914,11 +1874,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, auto param = Config::Str2Map(parameter); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int nrow = static_cast(nindptr - 1); @@ -1950,11 +1906,7 @@ int LGBM_BoosterPredictSparseOutput(BoosterHandle handle, auto param = Config::Str2Map(parameter); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); if (matrix_type == C_API_MATRIX_TYPE_CSR) { if (num_col_or_row <= 0) { Log::Fatal("The number of columns should be greater than zero."); @@ -2039,11 +1991,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, auto param = Config::Str2Map(parameter); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); ref_booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, config); @@ -2073,11 +2021,7 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle, data_type, static_cast(num_col))); - if (fastConfig_ptr->config.num_threads > 0) { - omp_set_num_threads(fastConfig_ptr->config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(fastConfig_ptr->config.num_threads); fastConfig_ptr->booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, fastConfig_ptr->config); @@ -2123,11 +2067,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, auto param = Config::Str2Map(parameter); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); int num_threads = OMP_NUM_THREADS(); int ncol = static_cast(ncol_ptr - 1); std::vector> iterators(num_threads, std::vector()); @@ -2170,11 +2110,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, auto param = Config::Str2Map(parameter); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); ref_booster->Predict(start_iteration, num_iteration, predict_type, nrow, ncol, get_row_fun, @@ -2197,11 +2133,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, auto param = Config::Str2Map(parameter); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major); ref_booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, config); @@ -2225,11 +2157,7 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle, data_type, ncol)); - if (fastConfig_ptr->config.num_threads > 0) { - omp_set_num_threads(fastConfig_ptr->config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(fastConfig_ptr->config.num_threads); fastConfig_ptr->booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, fastConfig_ptr->config); @@ -2267,11 +2195,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle, auto param = Config::Str2Map(parameter); Config config; config.Set(param); - if (config.num_threads > 0) { - omp_set_num_threads(config.num_threads); - } else { - omp_reset_num_threads(); - } + OMP_SET_NUM_THREADS(config.num_threads); Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type); ref_booster->Predict(start_iteration, num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len); From 5596a477085fe7f00efde3d3281f96deb69816da Mon Sep 17 00:00:00 2001 From: hzy46 Date: Thu, 28 Oct 2021 14:55:44 +0800 Subject: [PATCH 8/8] fix --- include/LightGBM/utils/openmp_wrapper.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h index b826e79e055a..a337fc353b75 100644 --- a/include/LightGBM/utils/openmp_wrapper.h +++ b/include/LightGBM/utils/openmp_wrapper.h @@ -29,8 +29,7 @@ inline void OMP_SET_NUM_THREADS(int num_threads) { static const int default_omp_num_threads = OMP_NUM_THREADS(); if (num_threads > 0) { omp_set_num_threads(num_threads); - } - else { + } else { omp_set_num_threads(default_omp_num_threads); } }