Skip to content

Commit

Permalink
allow caller to place the missing bin at the high end of the feature …
Browse files Browse the repository at this point in the history
…values
  • Loading branch information
paulbkoch committed Dec 26, 2024
1 parent 42875bc commit 5e999b6
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 26 deletions.
87 changes: 62 additions & 25 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
const size_t iDimension,
const Bin<FloatMain, UIntMain, true, true, bHessian>* const* const apBins,
const TreeNode<bHessian>* pMissingValueTreeNode,
const size_t cSlices
#ifndef NDEBUG
,
const size_t cBins
#endif // NDEBUG
) {
const size_t cSlices,
const size_t cBins) {
LOG_0(Trace_Verbose, "Entered Flatten");

EBM_ASSERT(nullptr != pBoosterShell);
Expand Down Expand Up @@ -178,6 +174,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
pUpdateScore = aUpdateScore;

if(bMissing) {
EBM_ASSERT(2 <= cSlices); // no cuts if there was only missing bin

// always put a split on the missing bin
*pSplit = 1;
++pSplit;
Expand All @@ -199,6 +197,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
const bool bUpdateWithHessian = bHessian && !(TermBoostFlags_DisableNewtonUpdate & flags);

TreeNode<bHessian>* pParent = nullptr;
bool bDone = false;

while(true) {
if(UNPREDICTABLE(pTreeNode->AFTER_IsSplit())) {
Expand Down Expand Up @@ -253,11 +252,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
}
EBM_ASSERT(!bNominal);

// if !bNominal, check the bin above and below for order
EBM_ASSERT(apBins == ppBinLast || *(ppBinLast - 1) < *ppBinLast);
EBM_ASSERT(ppBinLast == apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{2} : size_t{1})) ||
*ppBinLast < *(ppBinLast + 1));

iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0);

while(true) { // not a real loop
Expand All @@ -267,8 +261,17 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
pMissingBin = pTreeNode->GetBin();
}
if(1 == iEdge) {
// this cut would isolate the missing bin, but we handle those scores separately
break;
}
} else if(TermBoostFlags_MissingHigh & flags) {
++iEdge; // missing is at index 0 in the model, so we are offset by one
pMissingBin = pTreeNode->GetBin();
EBM_ASSERT(iEdge <= cBins + 1);
if(bDone) {
// this cut would isolate the missing bin, but we handle those scores separately
goto done;
}
}
}

Expand All @@ -290,6 +293,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
deltaStepMax);
}

EBM_ASSERT(pUpdateScore < aUpdateScore + cScores * cSlices);
*pUpdateScore = static_cast<FloatScore>(updateScore);
++pUpdateScore;

Expand All @@ -316,10 +320,27 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,

while(true) {
if(nullptr == pTreeNode) {
done:;
EBM_ASSERT(cSamplesTotalDebug == cSamplesExpectedDebug);

EBM_ASSERT(bNominal || pUpdateScore == aUpdateScore + cScores * cSlices);

EBM_ASSERT(bNominal || pSplit == cSlices - 1 + pInnerTermUpdate->GetSplitPointer(iDimension));

#ifndef NDEBUG
UIntSplit prevDebug = 0;
for(size_t iDebug = 0; iDebug < cSlices - 1; ++iDebug) {
UIntSplit curDebug = pInnerTermUpdate->GetSplitPointer(iDimension)[iDebug];
EBM_ASSERT(prevDebug < curDebug);
prevDebug = curDebug;
}
EBM_ASSERT(prevDebug < cBins);
#endif

EBM_ASSERT(nullptr == pMissingValueTreeNode || nullptr != pMissingBin);
if(nullptr != pMissingBin) {
EBM_ASSERT(bMissing);

FloatScore hess = static_cast<FloatCalc>(pMissingBin->GetWeight());
const auto* pGradientPair = pMissingBin->GetGradientPairs();
const auto* const pGradientPairEnd = pGradientPair + cScores;
Expand Down Expand Up @@ -353,12 +374,22 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
if(bMissing) {
if(TermBoostFlags_MissingLow & flags) {
if(1 == iEdge) {
// this cut would isolate the missing bin, but missing already has a cut
break;
}
} else if(TermBoostFlags_MissingHigh & flags) {
EBM_ASSERT(iEdge <= cBins);
if(cBins == iEdge) {
// This cut would isolate the missing bin, but missing already has a cut.
// We still need to find the missing bin though in the tree, so continue.
bDone = true;
break;
}
}
}

EBM_ASSERT(!IsConvertError<UIntSplit>(iEdge));
EBM_ASSERT(pSplit < cSlices - 1 + pInnerTermUpdate->GetSplitPointer(iDimension));
*pSplit = static_cast<UIntSplit>(iEdge);
++pSplit;

Expand Down Expand Up @@ -865,13 +896,14 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo

const TreeNode<bHessian, GetArrayScores(cCompilerScores)>* pMissingValueTreeNode = nullptr;
if(TermBoostFlags_MissingLow & flags) {
if(bMissing) {
if(!bNominal) {
pMissingBin = pBin;
}
*ppBin = pBin;
if(bMissing && !bNominal) {
pMissingBin = pBin;
}
} else if(TermBoostFlags_MissingHigh & flags) {
if(bMissing && !bNominal) {
pMissingBin = pBin;
// the concept of TermBoostFlags_MissingHigh does not exist for nominals
pBin = IndexBin(pBin, cBytesPerBin);
++ppBin;
}
} else {
if(bMissing) {
Expand All @@ -888,6 +920,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
++ppBin;
} while(pBinsEnd != pBin);

if(TermBoostFlags_MissingHigh & flags) {
if(bMissing && !bNominal) {
*ppBin = aBins;
++ppBin;
}
}

if(bNominal) {
std::sort(apBins,
ppBin,
Expand Down Expand Up @@ -1072,15 +1111,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
iDimension,
reinterpret_cast<const Bin<FloatMain, UIntMain, true, true, bHessian>* const*>(apBins),
nullptr != pMissingValueTreeNode ? pMissingValueTreeNode->Downgrade() : nullptr,
cSlices
#ifndef NDEBUG
,
cBins
#endif // NDEBUG
);
cSlices,
cBins);

EBM_ASSERT(!bMissing || 2 <= pBoosterShell->GetInnerTermUpdate()->GetCountSlices(iDimension));
EBM_ASSERT(!bMissing || *pBoosterShell->GetInnerTermUpdate()->GetSplitPointer(iDimension) == 1);
EBM_ASSERT(
error != Error_None || !bMissing || 2 <= pBoosterShell->GetInnerTermUpdate()->GetCountSlices(iDimension));
EBM_ASSERT(
error != Error_None || !bMissing || *pBoosterShell->GetInnerTermUpdate()->GetSplitPointer(iDimension) == 1);

return error;
}
Expand Down
2 changes: 1 addition & 1 deletion shared/libebm/tests/boosting_unusual_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2175,7 +2175,7 @@ static double RandomizedTesting(const AccelerationFlags acceleration) {
}

TEST_CASE("stress test, boosting") {
const double expected = 26758407585917.129;
const double expected = 26746562197367.172;

double validationMetricExact = RandomizedTesting(AccelerationFlags_NONE);
CHECK(validationMetricExact == expected);
Expand Down

0 comments on commit 5e999b6

Please sign in to comment.