[jvm-packages] Fix early stop with xgboost4j-spark (#4176)
* Fix early stop with xgboost4j-spark * Update XGBoost.java * Update XGBoost.java * Update XGBoost.java To use -Float.MAX_VALUE as the lower bound, in case there is positive metric. * Only update best score if the current score is better (no update when equal) * Update xgboost-spark tutorial to fix early stopping docs.
This commit is contained in:
@@ -140,6 +140,8 @@ public class XGBoost {
|
||||
//collect eval matrices
|
||||
String[] evalNames;
|
||||
DMatrix[] evalMats;
|
||||
float bestScore;
|
||||
int bestIteration;
|
||||
List<String> names = new ArrayList<String>();
|
||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||
|
||||
@@ -150,6 +152,12 @@ public class XGBoost {
|
||||
|
||||
evalNames = names.toArray(new String[names.size()]);
|
||||
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||
if (isMaximizeEvaluation(params)) {
|
||||
bestScore = -Float.MAX_VALUE;
|
||||
} else {
|
||||
bestScore = Float.MAX_VALUE;
|
||||
}
|
||||
bestIteration = 0;
|
||||
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
|
||||
|
||||
//collect all data matrices
|
||||
@@ -196,12 +204,27 @@ public class XGBoost {
|
||||
for (int i = 0; i < metricsOut.length; i++) {
|
||||
metrics[i][iter] = metricsOut[i];
|
||||
}
|
||||
|
||||
// If there is more than one evaluation dataset, the last one is used
|
||||
// to determine early stop.
|
||||
float score = metricsOut[metricsOut.length - 1];
|
||||
if (isMaximizeEvaluation(params)) {
|
||||
// Update best score if the current score is better (no update when equal)
|
||||
if (score > bestScore) {
|
||||
bestScore = score;
|
||||
bestIteration = iter;
|
||||
}
|
||||
} else {
|
||||
if (score < bestScore) {
|
||||
bestScore = score;
|
||||
bestIteration = iter;
|
||||
}
|
||||
}
|
||||
if (earlyStoppingRounds > 0) {
|
||||
boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter);
|
||||
if (!onTrack) {
|
||||
String reversedDirection = getReversedDirection(params);
|
||||
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
|
||||
Rabit.trackerPrint(String.format(
|
||||
"early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection));
|
||||
"early stopping after %d rounds away from the best iteration",
|
||||
earlyStoppingRounds));
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -214,42 +237,11 @@ public class XGBoost {
|
||||
return booster;
|
||||
}
|
||||
|
||||
static boolean judgeIfTrainingOnTrack(
|
||||
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
|
||||
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
|
||||
boolean onTrack = false;
|
||||
// we don't need to consider iterations before reaching the `earlyStoppingRounds`th iteration
|
||||
if (iter < earlyStoppingRounds - 1) {
|
||||
return true;
|
||||
}
|
||||
for (int metricsId = metrics.length == 1 ? 0 : 1; metricsId < metrics.length; metricsId++) {
|
||||
float[] criterion = metrics[metricsId];
|
||||
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
|
||||
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
|
||||
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
|
||||
// as `onTrack`
|
||||
onTrack |= maximizeEvaluationMetrics ?
|
||||
criterion[iter - shift] >= criterion[iter - shift - 1] :
|
||||
criterion[iter - shift] <= criterion[iter - shift - 1];
|
||||
}
|
||||
if (!onTrack) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return onTrack;
|
||||
static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
|
||||
return iter - bestIteration >= earlyStoppingRounds;
|
||||
}
|
||||
|
||||
private static String getReversedDirection(Map<String, Object> params) {
|
||||
String reversedDirection = null;
|
||||
if (Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
|
||||
reversedDirection = "descending";
|
||||
} else if (!Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
|
||||
reversedDirection = "ascending";
|
||||
}
|
||||
return reversedDirection;
|
||||
}
|
||||
|
||||
private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
|
||||
private static boolean isMaximizeEvaluation(Map<String, Object> params) {
|
||||
try {
|
||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||
assert(maximize != null);
|
||||
|
||||
@@ -154,188 +154,159 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testDescendMetricsWithBoundaryCondition() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "false");
|
||||
}
|
||||
};
|
||||
int totalIterations = 10;
|
||||
int earlyStoppingRounds = 10;
|
||||
// maximize_evaluation_metrics = false
|
||||
int totalIterations = 11;
|
||||
int earlyStoppingRound = 10;
|
||||
float[][] metrics = new float[1][totalIterations];
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
int bestIteration = 0;
|
||||
|
||||
for (int itr = 0; itr < totalIterations; itr++) {
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||
itr);
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, itr, bestIteration);
|
||||
if (itr == totalIterations - 1) {
|
||||
TestCase.assertFalse(onTrack);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||
totalIterations - 1);
|
||||
TestCase.assertTrue(onTrack);
|
||||
TestCase.assertTrue(es);
|
||||
} else {
|
||||
TestCase.assertTrue(onTrack);
|
||||
TestCase.assertFalse(es);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEarlyStoppingForMultipleMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "true");
|
||||
}
|
||||
};
|
||||
// maximize_evaluation_metrics = true
|
||||
int earlyStoppingRound = 3;
|
||||
int totalIterations = 5;
|
||||
int numOfMetrics = 3;
|
||||
float[][] metrics = new float[numOfMetrics][totalIterations];
|
||||
// Only assign metric values to the first dataset, zeros for other datasets
|
||||
for (int i = 0; i < numOfMetrics; i++) {
|
||||
for (int j = 0; j < totalIterations; j++) {
|
||||
metrics[0][j] = j;
|
||||
}
|
||||
}
|
||||
int bestIteration;
|
||||
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
TestCase.assertTrue(onTrack);
|
||||
bestIteration = i;
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
|
||||
TestCase.assertFalse(es);
|
||||
}
|
||||
|
||||
// when we have multiple datasets, only the last one is used to determine early stop
|
||||
// Here we changed the metric of the first dataset; it doesn't have any effect on the final result
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
// when we have multiple datasets, the training metrics are not considered
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
TestCase.assertTrue(onTrack);
|
||||
bestIteration = i;
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
|
||||
TestCase.assertFalse(es);
|
||||
}
|
||||
|
||||
// Now assign metric values to the last dataset.
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[1][i] = totalIterations - i;
|
||||
metrics[2][i] = totalIterations - i;
|
||||
}
|
||||
bestIteration = 0;
|
||||
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
// if any metric is off, we need to stop
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
if (i >= earlyStoppingRound - 1) {
|
||||
TestCase.assertFalse(onTrack);
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
|
||||
if (i >= earlyStoppingRound) {
|
||||
TestCase.assertTrue(es);
|
||||
} else {
|
||||
TestCase.assertTrue(onTrack);
|
||||
TestCase.assertFalse(es);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDescendMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "false");
|
||||
}
|
||||
};
|
||||
// maximize_evaluation_metrics = false
|
||||
int totalIterations = 10;
|
||||
int earlyStoppingRounds = 5;
|
||||
float[][] metrics = new float[1][totalIterations];
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||
totalIterations - 1);
|
||||
TestCase.assertFalse(onTrack);
|
||||
int bestIteration = 0;
|
||||
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertTrue(es);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||
totalIterations - 1);
|
||||
TestCase.assertTrue(onTrack);
|
||||
bestIteration = totalIterations - 1;
|
||||
|
||||
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertFalse(es);
|
||||
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
metrics[0][5] = 1;
|
||||
metrics[0][6] = 2;
|
||||
metrics[0][7] = 3;
|
||||
metrics[0][8] = 4;
|
||||
metrics[0][9] = 1;
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||
totalIterations - 1);
|
||||
TestCase.assertTrue(onTrack);
|
||||
metrics[0][4] = 1;
|
||||
metrics[0][9] = 5;
|
||||
|
||||
bestIteration = 4;
|
||||
|
||||
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertTrue(es);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAscendMetricsWithBoundaryCondition() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "true");
|
||||
}
|
||||
};
|
||||
int totalIterations = 10;
|
||||
// maximize_evaluation_metrics = true
|
||||
int totalIterations = 11;
|
||||
int earlyStoppingRounds = 10;
|
||||
float[][] metrics = new float[1][totalIterations];
|
||||
for (int iter = 0; iter < totalIterations; iter++) {
|
||||
if (iter == totalIterations - 1) {
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
||||
TestCase.assertTrue(onTrack);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
||||
TestCase.assertFalse(onTrack);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
int bestIteration = 0;
|
||||
|
||||
for (int itr = 0; itr < totalIterations; itr++) {
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, itr, bestIteration);
|
||||
if (itr == totalIterations - 1) {
|
||||
TestCase.assertTrue(es);
|
||||
} else {
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
||||
TestCase.assertTrue(onTrack);
|
||||
TestCase.assertFalse(es);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAscendMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "true");
|
||||
}
|
||||
};
|
||||
// maximize_evaluation_metrics = true
|
||||
int totalIterations = 10;
|
||||
int earlyStoppingRounds = 5;
|
||||
float[][] metrics = new float[1][totalIterations];
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
||||
TestCase.assertTrue(onTrack);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
||||
TestCase.assertFalse(onTrack);
|
||||
int bestIteration = 0;
|
||||
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertTrue(es);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
metrics[0][5] = 9;
|
||||
metrics[0][6] = 8;
|
||||
metrics[0][7] = 7;
|
||||
metrics[0][8] = 6;
|
||||
metrics[0][9] = 9;
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
||||
TestCase.assertTrue(onTrack);
|
||||
bestIteration = totalIterations - 1;
|
||||
|
||||
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertFalse(es);
|
||||
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
metrics[0][4] = 9;
|
||||
metrics[0][9] = 4;
|
||||
|
||||
bestIteration = 4;
|
||||
|
||||
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertTrue(es);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -362,13 +333,13 @@ public class BoosterImplTest {
|
||||
|
||||
// Make sure we've stopped early.
|
||||
for (int w = 0; w < watches.size(); w++) {
|
||||
for (int r = 0; r < earlyStoppingRound; r++) {
|
||||
for (int r = 0; r <= earlyStoppingRound; r++) {
|
||||
TestCase.assertFalse(0.0f == metrics[w][r]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int w = 0; w < watches.size(); w++) {
|
||||
for (int r = earlyStoppingRound; r < round; r++) {
|
||||
for (int r = earlyStoppingRound + 1; r < round; r++) {
|
||||
TestCase.assertEquals(0.0f, metrics[w][r]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user