[Blocking][jvm-packages] fix the early stopping feature (#3808)
* add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * temp * add method for classifier and regressor * update tutorial * address the comments * update
This commit is contained in:
@@ -152,6 +152,66 @@ public class BoosterImplTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDescendMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "false");
|
||||
}
|
||||
};
|
||||
float[][] metrics = new float[1][5];
|
||||
for (int i = 0; i < 5; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||
TestCase.assertFalse(onTrack);
|
||||
for (int i = 0; i < 5; i++) {
|
||||
metrics[0][i] = 5 - i;
|
||||
}
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||
TestCase.assertTrue(onTrack);
|
||||
for (int i = 0; i < 5; i++) {
|
||||
metrics[0][i] = 5 - i;
|
||||
}
|
||||
metrics[0][0] = 1;
|
||||
metrics[0][2] = 5;
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||
TestCase.assertTrue(onTrack);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAscendMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "true");
|
||||
}
|
||||
};
|
||||
float[][] metrics = new float[1][5];
|
||||
for (int i = 0; i < 5; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||
TestCase.assertTrue(onTrack);
|
||||
for (int i = 0; i < 5; i++) {
|
||||
metrics[0][i] = 5 - i;
|
||||
}
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||
TestCase.assertFalse(onTrack);
|
||||
for (int i = 0; i < 5; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
metrics[0][0] = 6;
|
||||
metrics[0][2] = 1;
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||
TestCase.assertTrue(onTrack);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBoosterEarlyStop() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
@@ -162,6 +222,7 @@ public class BoosterImplTest {
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "false");
|
||||
}
|
||||
};
|
||||
Map<String, DMatrix> watches = new LinkedHashMap<>();
|
||||
|
||||
Reference in New Issue
Block a user