[jvm-packages]support multiple validation datasets in Spark (#3910)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* wrap iterators

* enable copartition training and validationset

* add parameters

* converge code path and have init unit test

* enable multi evals for ranking

* unit test and doc

* update example

* fix early stopping

* address the offline comments

* udpate doc

* test eval metrics

* fix compilation issue

* fix example
This commit is contained in:
Nan Zhu
2018-12-17 21:03:57 -08:00
committed by GitHub
parent c8c7b9649c
commit c055a32609
14 changed files with 477 additions and 136 deletions

View File

@@ -222,14 +222,19 @@ public class XGBoost {
if (iter < earlyStoppingRounds - 1) {
return true;
}
float[] criterion = metrics[metrics.length - 1];
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
// as `onTrack`
onTrack |= maximizeEvaluationMetrics ?
criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1];
for (int metricsId = metrics.length == 1 ? 0 : 1; metricsId < metrics.length; metricsId++) {
float[] criterion = metrics[metricsId];
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
// as `onTrack`
onTrack |= maximizeEvaluationMetrics ?
criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1];
}
if (!onTrack) {
return false;
}
}
return onTrack;
}

View File

@@ -185,6 +185,51 @@ public class BoosterImplTest {
}
}
@Test
public void testEarlyStoppingForMultipleMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
int earlyStoppingRound = 3;
int totalIterations = 5;
int numOfMetrics = 3;
float[][] metrics = new float[numOfMetrics][totalIterations];
for (int i = 0; i < numOfMetrics; i++) {
for (int j = 0; j < totalIterations; j++) {
metrics[0][j] = j;
}
}
for (int i = 0; i < totalIterations; i++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
TestCase.assertTrue(onTrack);
}
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
// when we have multiple datasets, the training metrics is not considered
for (int i = 0; i < totalIterations; i++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
TestCase.assertTrue(onTrack);
}
for (int i = 0; i < totalIterations; i++) {
metrics[1][i] = totalIterations - i;
}
for (int i = 0; i < totalIterations; i++) {
// if any metrics off, we need to stop
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
if (i >= earlyStoppingRound - 1) {
TestCase.assertFalse(onTrack);
} else {
TestCase.assertTrue(onTrack);
}
}
}
@Test
public void testDescendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {