[jvm-packages] Create demo and test for xgboost4j early stopping. (#7252)
This commit is contained in:
@@ -16,8 +16,6 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
@@ -347,6 +345,34 @@ public class BoosterImplTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEarlyStoppingAttributes() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "false");
|
||||
}
|
||||
};
|
||||
Map<String, DMatrix> watches = new LinkedHashMap<>();
|
||||
watches.put("training", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
int round = 30;
|
||||
int earlyStoppingRound = 4;
|
||||
float[][] metrics = new float[watches.size()][round];
|
||||
|
||||
Booster booster = XGBoost.train(trainMat, paramMap, round,
|
||||
watches, metrics, null, null, earlyStoppingRound);
|
||||
|
||||
int bestIter = Integer.valueOf(booster.getAttr("best_iteration"));
|
||||
float bestScore = Float.valueOf(booster.getAttr("best_score"));
|
||||
TestCase.assertEquals(bestIter, round - 1);
|
||||
TestCase.assertEquals(bestScore, metrics[watches.size() - 1][round - 1]);
|
||||
}
|
||||
|
||||
private void testWithQuantileHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
|
||||
Map<String, Object> paramMap, float threshold) throws XGBoostError {
|
||||
float[][] metrics = new float[watches.size()][round];
|
||||
|
||||
Reference in New Issue
Block a user