From 4ddf8d001c5706c5f745aa9b5272278980ecefa9 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 13 Oct 2021 14:22:40 +0800 Subject: [PATCH] Deterministic result for element-wise/mclass metrics. (#7303) Remove openmp reduction. --- src/metric/elementwise_metric.cu | 50 ++++++++++---------- src/metric/multiclass_metric.cu | 45 ++++++++++-------- tests/cpp/metric/test_elementwise_metric.cc | 51 +++++++++++++++++++++ tests/cpp/metric/test_multiclass_metric.cc | 40 +++++++++++++++- 4 files changed, 140 insertions(+), 46 deletions(-) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index d8fd9e7fda5b..29130c89e4f0 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -14,6 +14,7 @@ #include "metric_common.h" #include "../common/math.h" #include "../common/common.h" +#include "../common/threading_utils.h" #if defined(XGBOOST_USE_CUDA) #include // thrust::cuda::par @@ -34,29 +35,29 @@ class ElementWiseMetricsReduction { public: explicit ElementWiseMetricsReduction(EvalRow policy) : policy_(std::move(policy)) {} - PackedReduceResult CpuReduceMetrics( - const HostDeviceVector& weights, - const HostDeviceVector& labels, - const HostDeviceVector& preds) const { + PackedReduceResult + CpuReduceMetrics(const HostDeviceVector &weights, + const HostDeviceVector &labels, + const HostDeviceVector &preds, + int32_t n_threads) const { size_t ndata = labels.Size(); const auto& h_labels = labels.HostVector(); const auto& h_weights = weights.HostVector(); const auto& h_preds = preds.HostVector(); - bst_float residue_sum = 0; - bst_float weights_sum = 0; - - dmlc::OMPException exc; -#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static) - for (omp_ulong i = 0; i < ndata; ++i) { - exc.Run([&]() { - const bst_float wt = h_weights.size() > 0 ? 
h_weights[i] : 1.0f; - residue_sum += policy_.EvalRow(h_labels[i], h_preds[i]) * wt; - weights_sum += wt; - }); - } - exc.Rethrow(); + std::vector score_tloc(n_threads, 0.0); + std::vector weight_tloc(n_threads, 0.0); + + common::ParallelFor(ndata, n_threads, [&](size_t i) { + float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f; + auto t_idx = omp_get_thread_num(); + score_tloc[t_idx] += policy_.EvalRow(h_labels[i], h_preds[i]) * wt; + weight_tloc[t_idx] += wt; + }); + double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0); + double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0); + PackedReduceResult res { residue_sum, weights_sum }; return res; } @@ -100,19 +101,19 @@ class ElementWiseMetricsReduction { #endif // XGBOOST_USE_CUDA PackedReduceResult Reduce( - const GenericParameter &tparam, - int device, + const GenericParameter &ctx, const HostDeviceVector& weights, const HostDeviceVector& labels, const HostDeviceVector& preds) { PackedReduceResult result; - if (device < 0) { - result = CpuReduceMetrics(weights, labels, preds); + if (ctx.gpu_id < 0) { + auto n_threads = ctx.Threads(); + result = CpuReduceMetrics(weights, labels, preds, n_threads); } #if defined(XGBOOST_USE_CUDA) else { // NOLINT - device_ = device; + device_ = ctx.gpu_id; preds.SetDevice(device_); labels.SetDevice(device_); weights.SetDevice(device_); @@ -365,10 +366,7 @@ struct EvalEWiseBase : public Metric { CHECK_EQ(preds.Size(), info.labels_.Size()) << "label and prediction size not match, " << "hint: use merror or mlogloss for multi-class classification"; - int device = tparam_->gpu_id; - - auto result = - reducer_.Reduce(*tparam_, device, info.weights_, info.labels_, preds); + auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels_, preds); double dat[2] { result.Residue(), result.Weights() }; diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 3b20361b0fd5..580edf4532df 100644 --- 
a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -6,11 +6,14 @@ */ #include #include + +#include #include #include "metric_common.h" #include "../common/math.h" #include "../common/common.h" +#include "../common/threading_utils.h" #if defined(XGBOOST_USE_CUDA) #include // thrust::cuda::par @@ -37,38 +40,41 @@ class MultiClassMetricsReduction { public: MultiClassMetricsReduction() = default; - PackedReduceResult CpuReduceMetrics( - const HostDeviceVector& weights, - const HostDeviceVector& labels, - const HostDeviceVector& preds, - const size_t n_class) const { + PackedReduceResult + CpuReduceMetrics(const HostDeviceVector &weights, + const HostDeviceVector &labels, + const HostDeviceVector &preds, + const size_t n_class, int32_t n_threads) const { size_t ndata = labels.Size(); const auto& h_labels = labels.HostVector(); const auto& h_weights = weights.HostVector(); const auto& h_preds = preds.HostVector(); - bst_float residue_sum = 0; - bst_float weights_sum = 0; - int label_error = 0; + std::atomic label_error {0}; bool const is_null_weight = weights.Size() == 0; - dmlc::OMPException exc; -#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static) - for (omp_ulong idx = 0; idx < ndata; ++idx) { - exc.Run([&]() { + std::vector scores_tloc(n_threads, 0); + std::vector weights_tloc(n_threads, 0); + common::ParallelFor(ndata, n_threads, [&](size_t idx) { bst_float weight = is_null_weight ? 
1.0f : h_weights[idx]; auto label = static_cast(h_labels[idx]); if (label >= 0 && label < static_cast(n_class)) { - residue_sum += EvalRowPolicy::EvalRow( - label, h_preds.data() + idx * n_class, n_class) * weight; - weights_sum += weight; + auto t_idx = omp_get_thread_num(); + scores_tloc[t_idx] += + EvalRowPolicy::EvalRow(label, h_preds.data() + idx * n_class, + n_class) * + weight; + weights_tloc[t_idx] += weight; } else { label_error = label; } - }); - } - exc.Rethrow(); + }); + + double residue_sum = + std::accumulate(scores_tloc.cbegin(), scores_tloc.cend(), 0.0); + double weights_sum = + std::accumulate(weights_tloc.cbegin(), weights_tloc.cend(), 0.0); CheckLabelError(label_error, n_class); PackedReduceResult res { residue_sum, weights_sum }; @@ -131,7 +137,8 @@ class MultiClassMetricsReduction { PackedReduceResult result; if (device < 0) { - result = CpuReduceMetrics(weights, labels, preds, n_class); + result = + CpuReduceMetrics(weights, labels, preds, n_class, tparam.Threads()); } #if defined(XGBOOST_USE_CUDA) else { // NOLINT diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc index b26c72d2e3c4..dfb188b6be6f 100644 --- a/tests/cpp/metric/test_elementwise_metric.cc +++ b/tests/cpp/metric/test_elementwise_metric.cc @@ -2,9 +2,44 @@ * Copyright 2018-2019 XGBoost contributors */ #include +#include + #include +#include + #include "../helpers.h" +namespace xgboost { +namespace { +inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) { + auto lparam = CreateEmptyGenericParam(device); + std::unique_ptr metric{Metric::Create(name.c_str(), &lparam)}; + + HostDeviceVector predts; + MetaInfo info; + auto &h_labels = info.labels_.HostVector(); + auto &h_predts = predts.HostVector(); + + SimpleLCG lcg; + SimpleRealUniformDistribution dist{0.0f, 1.0f}; + + size_t n_samples = 2048; + h_labels.resize(n_samples); + h_predts.resize(n_samples); + + for (size_t i = 0; i < n_samples; ++i) { + 
h_predts[i] = dist(&lcg); + h_labels[i] = dist(&lcg); + } + + auto result = metric->Eval(predts, info, false); + for (size_t i = 0; i < 8; ++i) { + ASSERT_EQ(metric->Eval(predts, info, false), result); + } +} +} // anonymous namespace +} // namespace xgboost + TEST(Metric, DeclareUnifiedTest(RMSE)) { auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX); xgboost::Metric * metric = xgboost::Metric::Create("rmse", &lparam); @@ -26,6 +61,8 @@ TEST(Metric, DeclareUnifiedTest(RMSE)) { { 1, 2, 9, 8}), 0.6708f, 0.001f); delete metric; + + xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"rmse"}, GPUIDX); } TEST(Metric, DeclareUnifiedTest(RMSLE)) { @@ -49,6 +86,8 @@ TEST(Metric, DeclareUnifiedTest(RMSLE)) { { 0, 1, 2, 9, 8}), 0.2415f, 1e-4); delete metric; + + xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"rmsle"}, GPUIDX); } TEST(Metric, DeclareUnifiedTest(MAE)) { @@ -72,6 +111,8 @@ TEST(Metric, DeclareUnifiedTest(MAE)) { { 1, 2, 9, 8}), 0.54f, 0.001f); delete metric; + + xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mae"}, GPUIDX); } TEST(Metric, DeclareUnifiedTest(MAPE)) { @@ -95,6 +136,8 @@ TEST(Metric, DeclareUnifiedTest(MAPE)) { { 1, 2, 9, 8}), 1.3250f, 0.001f); delete metric; + + xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mape"}, GPUIDX); } TEST(Metric, DeclareUnifiedTest(MPHE)) { @@ -118,6 +161,8 @@ TEST(Metric, DeclareUnifiedTest(MPHE)) { { 1, 2, 9, 8}), 0.1922f, 1e-4); delete metric; + + xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX); } TEST(Metric, DeclareUnifiedTest(LogLoss)) { @@ -145,6 +190,8 @@ TEST(Metric, DeclareUnifiedTest(LogLoss)) { { 1, 2, 9, 8}), 1.3138f, 0.001f); delete metric; + + xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"logloss"}, GPUIDX); } TEST(Metric, DeclareUnifiedTest(Error)) { @@ -197,6 +244,8 @@ TEST(Metric, DeclareUnifiedTest(Error)) { { 1, 2, 9, 8}), 0.45f, 0.001f); delete metric; + + 
xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"error@0.5"}, GPUIDX); } TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) { @@ -224,4 +273,6 @@ TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) { { 1, 2, 9, 8}), 1.5783f, 0.001f); delete metric; + + xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"poisson-nloglik"}, GPUIDX); } diff --git a/tests/cpp/metric/test_multiclass_metric.cc b/tests/cpp/metric/test_multiclass_metric.cc index 95f2c417e946..6f8ff28094cd 100644 --- a/tests/cpp/metric/test_multiclass_metric.cc +++ b/tests/cpp/metric/test_multiclass_metric.cc @@ -4,6 +4,43 @@ #include "../helpers.h" +namespace xgboost { +inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device) { + auto lparam = CreateEmptyGenericParam(device); + std::unique_ptr metric{Metric::Create(name.c_str(), &lparam)}; + + HostDeviceVector predts; + MetaInfo info; + auto &h_labels = info.labels_.HostVector(); + auto &h_predts = predts.HostVector(); + + SimpleLCG lcg; + + size_t n_samples = 2048, n_classes = 4; + h_labels.resize(n_samples); + h_predts.resize(n_samples * n_classes); + + { + SimpleRealUniformDistribution dist{0.0f, static_cast(n_classes)}; + for (size_t i = 0; i < n_samples; ++i) { + h_labels[i] = dist(&lcg); + } + } + + { + SimpleRealUniformDistribution dist{0.0f, 1.0f}; + for (size_t i = 0; i < n_samples * n_classes; ++i) { + h_predts[i] = dist(&lcg); + } + } + + auto result = metric->Eval(predts, info, false); + for (size_t i = 0; i < 8; ++i) { + ASSERT_EQ(metric->Eval(predts, info, false), result); + } +} +} // namespace xgboost + inline void TestMultiClassError(int device) { auto lparam = xgboost::CreateEmptyGenericParam(device); lparam.gpu_id = device; @@ -17,12 +54,12 @@ inline void TestMultiClassError(int device) { {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}, {0, 1, 2}), 0.666f, 0.001f); - delete metric; } TEST(Metric, DeclareUnifiedTest(MultiClassError)) { TestMultiClassError(GPUIDX); + 
xgboost::CheckDeterministicMetricMultiClass(xgboost::StringView{"merror"}, GPUIDX); } inline void TestMultiClassLogLoss(int device) { @@ -44,6 +81,7 @@ inline void TestMultiClassLogLoss(int device) { TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) { TestMultiClassLogLoss(GPUIDX); + xgboost::CheckDeterministicMetricMultiClass(xgboost::StringView{"mlogloss"}, GPUIDX); } #if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)