[Blocking][jvm-packages] fix the early stopping feature (#3808)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* temp

* add method for classifier and regressor

* update tutorial

* address the comments

* update
This commit is contained in:
Nan Zhu
2018-10-23 14:53:13 -07:00
committed by GitHub
parent e26b5d63b2
commit 4ae225a08d
7 changed files with 134 additions and 14 deletions

View File

@@ -118,9 +118,9 @@ public class XGBoost {
* performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration
* @param earlyStoppingRound if non-zero, training would be stopped
* @param earlyStoppingRounds if non-zero, training would be stopped
* after a specified number of consecutive
* increases in any evaluation metric.
* goes to the unexpected direction in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @param booster train from scratch if set to null; train from an existing booster if not null.
@@ -134,7 +134,7 @@ public class XGBoost {
float[][] metrics,
IObjective obj,
IEvaluation eval,
int earlyStoppingRound,
int earlyStoppingRounds,
Booster booster) throws XGBoostError {
//collect eval matrixs
@@ -196,17 +196,14 @@ public class XGBoost {
for (int i = 0; i < metricsOut.length; i++) {
metrics[i][iter] = metricsOut[i];
}
boolean decreasing = true;
float[] criterion = metrics[metrics.length - 1];
for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) {
decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1];
}
if (!decreasing) {
Rabit.trackerPrint(String.format(
"early stopping after %d decreasing rounds", earlyStoppingRound));
break;
if (earlyStoppingRounds > 0) {
boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter);
if (!onTrack) {
String reversedDirection = getReversedDirection(params);
Rabit.trackerPrint(String.format(
"early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection));
break;
}
}
if (Rabit.getRank() == 0) {
Rabit.trackerPrint(evalInfo + '\n');
@@ -217,6 +214,41 @@ public class XGBoost {
return booster;
}
static boolean judgeIfTrainingOnTrack(
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
boolean onTrack = false;
float[] criterion = metrics[metrics.length - 1];
for (int shift = 0; shift < Math.min(iter, earlyStoppingRounds) - 1; shift++) {
onTrack |= maximizeEvaluationMetrics ?
criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1];
}
return onTrack;
}
private static String getReversedDirection(Map<String, Object> params) {
String reversedDirection = null;
if (Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
reversedDirection = "descending";
} else if (!Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
reversedDirection = "ascending";
}
return reversedDirection;
}
private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
try {
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
assert(maximize != null);
return Boolean.valueOf(maximize);
} catch (Exception ex) {
logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," +
" allowed value: true/false", ex);
throw ex;
}
}
/**
* Cross-validation with given parameters.
*

View File

@@ -152,6 +152,66 @@ public class BoosterImplTest {
}
}
@Test
public void testDescendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "false");
}
};
float[][] metrics = new float[1][5];
for (int i = 0; i < 5; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
TestCase.assertFalse(onTrack);
for (int i = 0; i < 5; i++) {
metrics[0][i] = 5 - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
TestCase.assertTrue(onTrack);
for (int i = 0; i < 5; i++) {
metrics[0][i] = 5 - i;
}
metrics[0][0] = 1;
metrics[0][2] = 5;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
TestCase.assertTrue(onTrack);
}
@Test
public void testAscendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
float[][] metrics = new float[1][5];
for (int i = 0; i < 5; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
TestCase.assertTrue(onTrack);
for (int i = 0; i < 5; i++) {
metrics[0][i] = 5 - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
TestCase.assertFalse(onTrack);
for (int i = 0; i < 5; i++) {
metrics[0][i] = i;
}
metrics[0][0] = 6;
metrics[0][2] = 1;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
TestCase.assertTrue(onTrack);
}
@Test
public void testBoosterEarlyStop() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
@@ -162,6 +222,7 @@ public class BoosterImplTest {
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "false");
}
};
Map<String, DMatrix> watches = new LinkedHashMap<>();