Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jvm-packages]Fix early stopping condition #3928

Merged
merged 15 commits into from
Nov 24, 2018
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,11 @@ static boolean judgeIfTrainingOnTrack(
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
boolean onTrack = false;
if (iter < earlyStoppingRounds - 1) {
return true;
}
float[] criterion = metrics[metrics.length - 1];
for (int shift = 0; shift < Math.min(iter, earlyStoppingRounds) - 1; shift++) {
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
onTrack |= maximizeEvaluationMetrics ?
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this |= mean that, if the metric is moving in the right direction any two consecutive steps within the earlyStoppingRounds from the current iteration, then this method will return true?

This may not be what people normally expect from setting early stopping. For example, I'm getting a real training progress below with earlyStoppingSteps set to 20: the training should stop around iterations 120 since the maximum PR AUC was observed around iteration 100. But the current logic seems to look for any upward pieces within earlyStoppingSteps and keep training.
image

In the python-package(see here), training stops if the current iteration is earlyStoppingSteps away from the best iteration. Should the spark version be consistent with the python implementation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, are you interested in filing a PR or an issue?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really write in Scala. So created this issue instead.

criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public void saveLoadModelWithStream() throws XGBoostError, IOException {
}

private static class IncreasingEval implements IEvaluation {
private int value = 0;
private int value = 1;

@Override
public String getMetric() {
Expand All @@ -152,6 +152,33 @@ public float eval(float[][] predicts, DMatrix dmat) {
}
}

@Test
public void testDescendMetricsWithBoundaryCondition() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "false");
}
};
int totalIterations = 10;
int earlyStoppingRounds = 10;
float[][] metrics = new float[1][totalIterations];
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertFalse(onTrack);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack);
}

@Test
public void testDescendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
Expand All @@ -162,24 +189,57 @@ public void testDescendMetrics() {
put("maximize_evaluation_metrics", "false");
}
};
float[][] metrics = new float[1][5];
for (int i = 0; i < 5; i++) {
int totalIterations = 10;
int earlyStoppingRounds = 5;
float[][] metrics = new float[1][totalIterations];
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertFalse(onTrack);
for (int i = 0; i < 5; i++) {
metrics[0][i] = 5 - i;
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
metrics[0][5] = 1;
metrics[0][6] = 2;
metrics[0][7] = 3;
metrics[0][8] = 4;
metrics[0][9] = 1;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack);
for (int i = 0; i < 5; i++) {
metrics[0][i] = 5 - i;
}

@Test
public void testAscendMetricsWithBoundaryCondition() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
int totalIterations = 10;
int earlyStoppingRounds = 10;
float[][] metrics = new float[1][totalIterations];
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
metrics[0][0] = 1;
metrics[0][2] = 5;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertTrue(onTrack);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertFalse(onTrack);
}

@Test
Expand All @@ -192,23 +252,28 @@ public void testAscendMetrics() {
put("maximize_evaluation_metrics", "true");
}
};
float[][] metrics = new float[1][5];
for (int i = 0; i < 5; i++) {
int totalIterations = 10;
int earlyStoppingRounds = 5;
float[][] metrics = new float[1][totalIterations];
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertTrue(onTrack);
for (int i = 0; i < 5; i++) {
metrics[0][i] = 5 - i;
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertFalse(onTrack);
for (int i = 0; i < 5; i++) {
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
metrics[0][0] = 6;
metrics[0][2] = 1;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
metrics[0][5] = 9;
metrics[0][6] = 8;
metrics[0][7] = 7;
metrics[0][8] = 6;
metrics[0][9] = 9;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertTrue(onTrack);
}

Expand Down Expand Up @@ -237,7 +302,13 @@ public void testBoosterEarlyStop() throws XGBoostError, IOException {

// Make sure we've stopped early.
for (int w = 0; w < watches.size(); w++) {
for (int r = earlyStoppingRound + 1; r < round; r++) {
for (int r = 0; r < earlyStoppingRound; r++) {
TestCase.assertFalse(0.0f == metrics[w][r]);
}
}

for (int w = 0; w < watches.size(); w++) {
for (int r = earlyStoppingRound; r < round; r++) {
TestCase.assertEquals(0.0f, metrics[w][r]);
}
}
Expand Down