Skip to content

Commit

Permalink
fix inconsistent results in ARM exact tests by removing purification …
Browse files Browse the repository at this point in the history
…from tests
  • Loading branch information
paulbkoch committed Dec 23, 2024
1 parent ac6452e commit 6a76e04
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions shared/libebm/tests/boosting_unusual_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2088,9 +2088,11 @@ TEST_CASE("stress test, boosting") {
// terms.push_back({0, 1, 2, 3}); // TODO: enable when fast enough
}
const size_t cRounds = 200;
std::vector<IntEbm> boostFlagsAny{TermBoostFlags_PurifyGain,
std::vector<IntEbm> boostFlagsAny{// TermBoostFlags_PurifyGain,
TermBoostFlags_DisableNewtonGain,
TermBoostFlags_DisableCategorical,
// TermBoostFlags_PurifyUpdate,
// TermBoostFlags_GradientSums, // does not return a metric
TermBoostFlags_DisableNewtonUpdate,
TermBoostFlags_RandomSplits};
std::vector<IntEbm> boostFlagsChoose{TermBoostFlags_Default,
Expand All @@ -2099,10 +2101,10 @@ TEST_CASE("stress test, boosting") {
TermBoostFlags_MissingSeparate,
TermBoostFlags_MissingDrop};

double validationMetric = 0.0;
double validationMetric = 1.0;

for(IntEbm classesCount = Task_Regression; classesCount < 5; ++classesCount) {
if(classesCount != Task_Regression && classesCount < 2) {
if(classesCount != Task_Regression && classesCount < 1) {
continue;
}
const auto train = MakeRandomDataset(rng, classesCount, cTrainSamples, features);
Expand Down Expand Up @@ -2159,9 +2161,13 @@ TEST_CASE("stress test, boosting") {
.validationMetric;
}
}
validationMetric += validationMetricIteration;
if(classesCount == 1) {
CHECK(-std::numeric_limits<double>::infinity() == validationMetric);
} else {
validationMetric *= validationMetricIteration;
}
}
}

CHECK(validationMetric == 42031.143270308334);
CHECK(validationMetric == 62013566170252.117);
}

0 comments on commit 6a76e04

Please sign in to comment.