[jvm-packages]Fix early stopping condition (#3928)
* add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * update version * 0.82 * fix early stopping condition * remove unused * update comments * update comments * update test
This commit is contained in:
parent
42cac4a30b
commit
9c4ff50e83
@ -218,8 +218,15 @@ public class XGBoost {
|
|||||||
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
|
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
|
||||||
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
|
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
|
||||||
boolean onTrack = false;
|
boolean onTrack = false;
|
||||||
|
// we don't need to check for early stopping before the `earlyStoppingRounds`-th iteration is reached
|
||||||
|
if (iter < earlyStoppingRounds - 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
float[] criterion = metrics[metrics.length - 1];
|
float[] criterion = metrics[metrics.length - 1];
|
||||||
for (int shift = 0; shift < Math.min(iter, earlyStoppingRounds) - 1; shift++) {
|
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
|
||||||
|
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
|
||||||
|
// iterations go in the expected direction, we should consider these `earlyStoppingRounds`
|
||||||
|
// as `onTrack`
|
||||||
onTrack |= maximizeEvaluationMetrics ?
|
onTrack |= maximizeEvaluationMetrics ?
|
||||||
criterion[iter - shift] >= criterion[iter - shift - 1] :
|
criterion[iter - shift] >= criterion[iter - shift - 1] :
|
||||||
criterion[iter - shift] <= criterion[iter - shift - 1];
|
criterion[iter - shift] <= criterion[iter - shift - 1];
|
||||||
|
|||||||
@ -139,7 +139,7 @@ public class BoosterImplTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static class IncreasingEval implements IEvaluation {
|
private static class IncreasingEval implements IEvaluation {
|
||||||
private int value = 0;
|
private int value = 1;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getMetric() {
|
public String getMetric() {
|
||||||
@ -152,6 +152,39 @@ public class BoosterImplTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDescendMetricsWithBoundaryCondition() {
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 3);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
put("maximize_evaluation_metrics", "false");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
int totalIterations = 10;
|
||||||
|
int earlyStoppingRounds = 10;
|
||||||
|
float[][] metrics = new float[1][totalIterations];
|
||||||
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
|
metrics[0][i] = i;
|
||||||
|
}
|
||||||
|
for (int itr = 0; itr < totalIterations; itr++) {
|
||||||
|
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||||
|
itr);
|
||||||
|
if (itr == totalIterations - 1) {
|
||||||
|
TestCase.assertFalse(onTrack);
|
||||||
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
|
metrics[0][i] = totalIterations - i;
|
||||||
|
}
|
||||||
|
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||||
|
totalIterations - 1);
|
||||||
|
TestCase.assertTrue(onTrack);
|
||||||
|
} else {
|
||||||
|
TestCase.assertTrue(onTrack);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDescendMetrics() {
|
public void testDescendMetrics() {
|
||||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
@ -162,26 +195,69 @@ public class BoosterImplTest {
|
|||||||
put("maximize_evaluation_metrics", "false");
|
put("maximize_evaluation_metrics", "false");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
float[][] metrics = new float[1][5];
|
int totalIterations = 10;
|
||||||
for (int i = 0; i < 5; i++) {
|
int earlyStoppingRounds = 5;
|
||||||
|
float[][] metrics = new float[1][totalIterations];
|
||||||
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = i;
|
metrics[0][i] = i;
|
||||||
}
|
}
|
||||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||||
|
totalIterations - 1);
|
||||||
TestCase.assertFalse(onTrack);
|
TestCase.assertFalse(onTrack);
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = 5 - i;
|
metrics[0][i] = totalIterations - i;
|
||||||
}
|
}
|
||||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||||
|
totalIterations - 1);
|
||||||
TestCase.assertTrue(onTrack);
|
TestCase.assertTrue(onTrack);
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = 5 - i;
|
metrics[0][i] = totalIterations - i;
|
||||||
}
|
}
|
||||||
metrics[0][0] = 1;
|
metrics[0][5] = 1;
|
||||||
metrics[0][2] = 5;
|
metrics[0][6] = 2;
|
||||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
metrics[0][7] = 3;
|
||||||
|
metrics[0][8] = 4;
|
||||||
|
metrics[0][9] = 1;
|
||||||
|
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||||
|
totalIterations - 1);
|
||||||
TestCase.assertTrue(onTrack);
|
TestCase.assertTrue(onTrack);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testAscendMetricsWithBoundaryCondition() {
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 3);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
put("maximize_evaluation_metrics", "true");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
int totalIterations = 10;
|
||||||
|
int earlyStoppingRounds = 10;
|
||||||
|
float[][] metrics = new float[1][totalIterations];
|
||||||
|
for (int iter = 0; iter < totalIterations; iter++) {
|
||||||
|
if (iter == totalIterations - 1) {
|
||||||
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
|
metrics[0][i] = i;
|
||||||
|
}
|
||||||
|
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
||||||
|
TestCase.assertTrue(onTrack);
|
||||||
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
|
metrics[0][i] = totalIterations - i;
|
||||||
|
}
|
||||||
|
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
||||||
|
TestCase.assertFalse(onTrack);
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
|
metrics[0][i] = totalIterations - i;
|
||||||
|
}
|
||||||
|
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
||||||
|
TestCase.assertTrue(onTrack);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAscendMetrics() {
|
public void testAscendMetrics() {
|
||||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
@ -192,23 +268,28 @@ public class BoosterImplTest {
|
|||||||
put("maximize_evaluation_metrics", "true");
|
put("maximize_evaluation_metrics", "true");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
float[][] metrics = new float[1][5];
|
int totalIterations = 10;
|
||||||
for (int i = 0; i < 5; i++) {
|
int earlyStoppingRounds = 5;
|
||||||
|
float[][] metrics = new float[1][totalIterations];
|
||||||
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = i;
|
metrics[0][i] = i;
|
||||||
}
|
}
|
||||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
||||||
TestCase.assertTrue(onTrack);
|
TestCase.assertTrue(onTrack);
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = 5 - i;
|
metrics[0][i] = totalIterations - i;
|
||||||
}
|
}
|
||||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
||||||
TestCase.assertFalse(onTrack);
|
TestCase.assertFalse(onTrack);
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = i;
|
metrics[0][i] = i;
|
||||||
}
|
}
|
||||||
metrics[0][0] = 6;
|
metrics[0][5] = 9;
|
||||||
metrics[0][2] = 1;
|
metrics[0][6] = 8;
|
||||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
metrics[0][7] = 7;
|
||||||
|
metrics[0][8] = 6;
|
||||||
|
metrics[0][9] = 9;
|
||||||
|
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
||||||
TestCase.assertTrue(onTrack);
|
TestCase.assertTrue(onTrack);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,7 +318,13 @@ public class BoosterImplTest {
|
|||||||
|
|
||||||
// Make sure we've stopped early.
|
// Make sure we've stopped early.
|
||||||
for (int w = 0; w < watches.size(); w++) {
|
for (int w = 0; w < watches.size(); w++) {
|
||||||
for (int r = earlyStoppingRound + 1; r < round; r++) {
|
for (int r = 0; r < earlyStoppingRound; r++) {
|
||||||
|
TestCase.assertFalse(0.0f == metrics[w][r]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int w = 0; w < watches.size(); w++) {
|
||||||
|
for (int r = earlyStoppingRound; r < round; r++) {
|
||||||
TestCase.assertEquals(0.0f, metrics[w][r]);
|
TestCase.assertEquals(0.0f, metrics[w][r]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user