Skip to content

Commit

Permalink
simplify single feature tree boosting code
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Dec 27, 2024
1 parent d7d7d1b commit 4171091
Showing 1 changed file with 27 additions and 32 deletions.
59 changes: 27 additions & 32 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,16 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
cSamplesTotalDebug += static_cast<size_t>(pTreeNode->GetBin()->GetCountSamples());
#endif // NDEBUG

const auto* const pGradientPairEnd = pTreeNode->GetBin()->GetGradientPairs() + cScores;
size_t iEdge;
const auto* const aGradientPair = pTreeNode->GetBin()->GetGradientPairs();
size_t iScore;
FloatScore hess;
const GradientPair<FloatMain, bHessian>* pGradientPair;
if(nullptr != ppBinCur) {
goto determine_bin;
}
EBM_ASSERT(!bNominal);

iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0);
iEdge = ppBinLast - apBins + 1;

while(true) { // not a real loop
if(bMissing) {
Expand All @@ -270,47 +271,41 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
// this cut would isolate the missing bin, but we handle those scores separately
break;
}
} else if(TermBoostFlags_MissingHigh & flags) {
} else {
++iEdge; // missing is at index 0 in the model, so we are offset by one
pMissingBin = pTreeNode->GetBin();
EBM_ASSERT(iEdge <= cBins + 1);
if(bDone) {
// this cut would isolate the missing bin, but we handle those scores separately
goto done;
if(TermBoostFlags_MissingHigh & flags) {
pMissingBin = pTreeNode->GetBin();
EBM_ASSERT(iEdge <= cBins + 1);
if(bDone) {
// this cut would isolate the missing bin, but we handle those scores separately
goto done;
}
}
} else if(TermBoostFlags_MissingSeparate & flags) {
++iEdge; // missing is at index 0 in the model, so we are offset by one
}
}

while(true) {
iScore = 0;
hess = static_cast<FloatCalc>(pTreeNode->GetBin()->GetWeight());
pGradientPair = pTreeNode->GetBin()->GetGradientPairs();
do {
FloatCalc updateScore;
if(bUpdateWithHessian) {
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
static_cast<FloatCalc>(aGradientPair[iScore].GetHess()),
regAlpha,
regLambda,
deltaStepMax);
} else {
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
static_cast<FloatCalc>(pTreeNode->GetBin()->GetWeight()),
regAlpha,
regLambda,
deltaStepMax);
hess = static_cast<FloatCalc>(pGradientPair->GetHess());
}
const FloatCalc updateScore = -CalcNegUpdate<true>(
static_cast<FloatCalc>(pGradientPair->m_sumGradients), hess, regAlpha, regLambda, deltaStepMax);

EBM_ASSERT(pUpdateScore < aUpdateScore + cScores * cSlices);
*pUpdateScore = static_cast<FloatScore>(updateScore);
++pUpdateScore;

++iScore;
} while(cScores != iScore);
++pGradientPair;
} while(pGradientPairEnd != pGradientPair);

if(nullptr == ppBinCur) {
break;
}
EBM_ASSERT(bNominal);

++ppBinCur;
if(ppBinLast < ppBinCur) {
break;
Expand Down Expand Up @@ -372,13 +367,12 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
}

done:;
EBM_ASSERT(cSamplesTotalDebug == cSamplesExpectedDebug);

#ifndef NDEBUG
EBM_ASSERT(cSamplesTotalDebug == cSamplesExpectedDebug);
EBM_ASSERT(bNominal || pUpdateScore == aUpdateScore + cScores * cSlices);

EBM_ASSERT(bNominal || pSplit == cSlices - 1 + pInnerTermUpdate->GetSplitPointer(iDimension));

#ifndef NDEBUG
UIntSplit prevDebug = 0;
for(size_t iDebug = 0; iDebug < cSlices - 1; ++iDebug) {
UIntSplit curDebug = pInnerTermUpdate->GetSplitPointer(iDimension)[iDebug];
Expand All @@ -388,7 +382,6 @@ done:;
EBM_ASSERT(prevDebug < cBins);
#endif

EBM_ASSERT(nullptr == pMissingValueTreeNode || nullptr != pMissingBin);
if(nullptr != pMissingBin) {
EBM_ASSERT(bMissing);

Expand All @@ -403,11 +396,13 @@ done:;
FloatCalc updateScore = -CalcNegUpdate<true>(
static_cast<FloatCalc>(pGradientPair->m_sumGradients), hess, regAlpha, regLambda, deltaStepMax);

*pMissingUpdateScore = updateScore;
*pMissingUpdateScore = static_cast<FloatScore>(updateScore);
++pMissingUpdateScore;

++pGradientPair;
} while(pGradientPairEnd != pGradientPair);
} else {
EBM_ASSERT(!bMissing || bNominal && (TermBoostFlags_MissingCategory & flags));
}

LOG_0(Trace_Verbose, "Exited Flatten");
Expand Down Expand Up @@ -903,7 +898,7 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
const auto* aSumBins = aBins;
if(bNominal) {
if(TermBoostFlags_MissingCategory & flags) {
// nothing to do
// Nothing to do. Treat missing like any other category.
} else {
if(bMissing) {
pMissingValueTreeNode = pRootTreeNode;
Expand Down

0 comments on commit 4171091

Please sign in to comment.