[jvm-packages] Create demo and test for xgboost4j early stopping. (#7252)

This commit is contained in:
Jiaming Yuan
2021-09-25 03:29:27 +08:00
committed by GitHub
parent 0ee11dac77
commit fbd58bf190
5 changed files with 103 additions and 9 deletions

View File

@@ -16,8 +16,6 @@
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
@@ -347,6 +345,34 @@ public class BoosterImplTest {
}
}
@Test
public void testEarlyStoppingAttributes() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix(this.train_uri);
DMatrix testMat = new DMatrix(this.test_uri);
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "false");
}
};
Map<String, DMatrix> watches = new LinkedHashMap<>();
watches.put("training", trainMat);
watches.put("test", testMat);
int round = 30;
int earlyStoppingRound = 4;
float[][] metrics = new float[watches.size()][round];
Booster booster = XGBoost.train(trainMat, paramMap, round,
watches, metrics, null, null, earlyStoppingRound);
int bestIter = Integer.valueOf(booster.getAttr("best_iteration"));
float bestScore = Float.valueOf(booster.getAttr("best_score"));
TestCase.assertEquals(bestIter, round - 1);
TestCase.assertEquals(bestScore, metrics[watches.size() - 1][round - 1]);
}
private void testWithQuantileHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
Map<String, Object> paramMap, float threshold) throws XGBoostError {
float[][] metrics = new float[watches.size()][round];