[jvm-packages]support multiple validation datasets in Spark (#3910)
* add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * wrap iterators * enable copartition training and validationset * add parameters * converge code path and have init unit test * enable multi evals for ranking * unit test and doc * update example * fix early stopping * address the offline comments * udpate doc * test eval metrics * fix compilation issue * fix example
This commit is contained in:
@@ -222,14 +222,19 @@ public class XGBoost {
|
||||
if (iter < earlyStoppingRounds - 1) {
|
||||
return true;
|
||||
}
|
||||
float[] criterion = metrics[metrics.length - 1];
|
||||
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
|
||||
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
|
||||
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
|
||||
// as `onTrack`
|
||||
onTrack |= maximizeEvaluationMetrics ?
|
||||
criterion[iter - shift] >= criterion[iter - shift - 1] :
|
||||
criterion[iter - shift] <= criterion[iter - shift - 1];
|
||||
for (int metricsId = metrics.length == 1 ? 0 : 1; metricsId < metrics.length; metricsId++) {
|
||||
float[] criterion = metrics[metricsId];
|
||||
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
|
||||
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
|
||||
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
|
||||
// as `onTrack`
|
||||
onTrack |= maximizeEvaluationMetrics ?
|
||||
criterion[iter - shift] >= criterion[iter - shift - 1] :
|
||||
criterion[iter - shift] <= criterion[iter - shift - 1];
|
||||
}
|
||||
if (!onTrack) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return onTrack;
|
||||
}
|
||||
|
||||
@@ -185,6 +185,51 @@ public class BoosterImplTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEarlyStoppingForMultipleMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "true");
|
||||
}
|
||||
};
|
||||
int earlyStoppingRound = 3;
|
||||
int totalIterations = 5;
|
||||
int numOfMetrics = 3;
|
||||
float[][] metrics = new float[numOfMetrics][totalIterations];
|
||||
for (int i = 0; i < numOfMetrics; i++) {
|
||||
for (int j = 0; j < totalIterations; j++) {
|
||||
metrics[0][j] = j;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
TestCase.assertTrue(onTrack);
|
||||
}
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
// when we have multiple datasets, the training metrics is not considered
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
TestCase.assertTrue(onTrack);
|
||||
}
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[1][i] = totalIterations - i;
|
||||
}
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
// if any metrics off, we need to stop
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
if (i >= earlyStoppingRound - 1) {
|
||||
TestCase.assertFalse(onTrack);
|
||||
} else {
|
||||
TestCase.assertTrue(onTrack);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDescendMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
|
||||
Reference in New Issue
Block a user