[jvm-packages] Fix early stop with xgboost4j-spark (#4176)
* Fix early stop with xgboost4j-spark * Update XGBoost.java * Update XGBoost.java * Update XGBoost.java To use -Float.MAX_VALUE as the lower bound, in case there is positive metric. * Only update best score if the current score is better (no update when equal) * Update xgboost-spark tutorial to fix early stopping docs.
This commit is contained in:
parent
7ea5675679
commit
9fefa2128d
@ -194,11 +194,11 @@ After we set XGBoostClassifier parameters and feature/label column, we can build
|
|||||||
Early Stopping
|
Early Stopping
|
||||||
----------------
|
----------------
|
||||||
|
|
||||||
Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds for the evaluation metric going to the unexpected direction to tolerate before stopping the training.
|
Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds if the evaluation metric going away from the best iteration and early stop training iterations.
|
||||||
|
|
||||||
In additional to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training.
|
In additional to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training.
|
||||||
|
|
||||||
After specifying these two parameters, the training would stop when the metrics goes to the other direction against the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` iterations.
|
For example, we need to maximize the evaluation metrics (set ``maximize_evaluation_metrics`` with true), and set ``num_early_stopping_rounds`` with 5. The evaluation metric of 10th iteration is the maximum one until now. In the following iterations, if there is no evaluation metric greater than the 10th iteration's (best one), the traning would be early stopped at 15th iteration.
|
||||||
|
|
||||||
Training with Evaluation Sets
|
Training with Evaluation Sets
|
||||||
----------------
|
----------------
|
||||||
|
|||||||
@ -140,6 +140,8 @@ public class XGBoost {
|
|||||||
//collect eval matrixs
|
//collect eval matrixs
|
||||||
String[] evalNames;
|
String[] evalNames;
|
||||||
DMatrix[] evalMats;
|
DMatrix[] evalMats;
|
||||||
|
float bestScore;
|
||||||
|
int bestIteration;
|
||||||
List<String> names = new ArrayList<String>();
|
List<String> names = new ArrayList<String>();
|
||||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||||
|
|
||||||
@ -150,6 +152,12 @@ public class XGBoost {
|
|||||||
|
|
||||||
evalNames = names.toArray(new String[names.size()]);
|
evalNames = names.toArray(new String[names.size()]);
|
||||||
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||||
|
if (isMaximizeEvaluation(params)) {
|
||||||
|
bestScore = -Float.MAX_VALUE;
|
||||||
|
} else {
|
||||||
|
bestScore = Float.MAX_VALUE;
|
||||||
|
}
|
||||||
|
bestIteration = 0;
|
||||||
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
|
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
|
||||||
|
|
||||||
//collect all data matrixs
|
//collect all data matrixs
|
||||||
@ -196,12 +204,27 @@ public class XGBoost {
|
|||||||
for (int i = 0; i < metricsOut.length; i++) {
|
for (int i = 0; i < metricsOut.length; i++) {
|
||||||
metrics[i][iter] = metricsOut[i];
|
metrics[i][iter] = metricsOut[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If there is more than one evaluation datasets, the last one would be used
|
||||||
|
// to determinate early stop.
|
||||||
|
float score = metricsOut[metricsOut.length - 1];
|
||||||
|
if (isMaximizeEvaluation(params)) {
|
||||||
|
// Update best score if the current score is better (no update when equal)
|
||||||
|
if (score > bestScore) {
|
||||||
|
bestScore = score;
|
||||||
|
bestIteration = iter;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (score < bestScore) {
|
||||||
|
bestScore = score;
|
||||||
|
bestIteration = iter;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (earlyStoppingRounds > 0) {
|
if (earlyStoppingRounds > 0) {
|
||||||
boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter);
|
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
|
||||||
if (!onTrack) {
|
|
||||||
String reversedDirection = getReversedDirection(params);
|
|
||||||
Rabit.trackerPrint(String.format(
|
Rabit.trackerPrint(String.format(
|
||||||
"early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection));
|
"early stopping after %d rounds away from the best iteration",
|
||||||
|
earlyStoppingRounds));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -214,42 +237,11 @@ public class XGBoost {
|
|||||||
return booster;
|
return booster;
|
||||||
}
|
}
|
||||||
|
|
||||||
static boolean judgeIfTrainingOnTrack(
|
static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
|
||||||
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
|
return iter - bestIteration >= earlyStoppingRounds;
|
||||||
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
|
|
||||||
boolean onTrack = false;
|
|
||||||
// we don't need to consider iterations before reaching to `earlyStoppingRounds`th iteration
|
|
||||||
if (iter < earlyStoppingRounds - 1) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
for (int metricsId = metrics.length == 1 ? 0 : 1; metricsId < metrics.length; metricsId++) {
|
|
||||||
float[] criterion = metrics[metricsId];
|
|
||||||
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
|
|
||||||
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
|
|
||||||
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
|
|
||||||
// as `onTrack`
|
|
||||||
onTrack |= maximizeEvaluationMetrics ?
|
|
||||||
criterion[iter - shift] >= criterion[iter - shift - 1] :
|
|
||||||
criterion[iter - shift] <= criterion[iter - shift - 1];
|
|
||||||
}
|
|
||||||
if (!onTrack) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return onTrack;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static String getReversedDirection(Map<String, Object> params) {
|
private static boolean isMaximizeEvaluation(Map<String, Object> params) {
|
||||||
String reversedDirection = null;
|
|
||||||
if (Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
|
|
||||||
reversedDirection = "descending";
|
|
||||||
} else if (!Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
|
|
||||||
reversedDirection = "ascending";
|
|
||||||
}
|
|
||||||
return reversedDirection;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
|
|
||||||
try {
|
try {
|
||||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||||
assert(maximize != null);
|
assert(maximize != null);
|
||||||
|
|||||||
@ -154,188 +154,159 @@ public class BoosterImplTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDescendMetricsWithBoundaryCondition() {
|
public void testDescendMetricsWithBoundaryCondition() {
|
||||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
// maximize_evaluation_metrics = false
|
||||||
{
|
int totalIterations = 11;
|
||||||
put("max_depth", 3);
|
int earlyStoppingRound = 10;
|
||||||
put("silent", 1);
|
|
||||||
put("objective", "binary:logistic");
|
|
||||||
put("maximize_evaluation_metrics", "false");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
int totalIterations = 10;
|
|
||||||
int earlyStoppingRounds = 10;
|
|
||||||
float[][] metrics = new float[1][totalIterations];
|
float[][] metrics = new float[1][totalIterations];
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = i;
|
metrics[0][i] = i;
|
||||||
}
|
}
|
||||||
|
int bestIteration = 0;
|
||||||
|
|
||||||
for (int itr = 0; itr < totalIterations; itr++) {
|
for (int itr = 0; itr < totalIterations; itr++) {
|
||||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, itr, bestIteration);
|
||||||
itr);
|
|
||||||
if (itr == totalIterations - 1) {
|
if (itr == totalIterations - 1) {
|
||||||
TestCase.assertFalse(onTrack);
|
TestCase.assertTrue(es);
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
|
||||||
metrics[0][i] = totalIterations - i;
|
|
||||||
}
|
|
||||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
|
||||||
totalIterations - 1);
|
|
||||||
TestCase.assertTrue(onTrack);
|
|
||||||
} else {
|
} else {
|
||||||
TestCase.assertTrue(onTrack);
|
TestCase.assertFalse(es);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEarlyStoppingForMultipleMetrics() {
|
public void testEarlyStoppingForMultipleMetrics() {
|
||||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
// maximize_evaluation_metrics = true
|
||||||
{
|
|
||||||
put("max_depth", 3);
|
|
||||||
put("silent", 1);
|
|
||||||
put("objective", "binary:logistic");
|
|
||||||
put("maximize_evaluation_metrics", "true");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
int earlyStoppingRound = 3;
|
int earlyStoppingRound = 3;
|
||||||
int totalIterations = 5;
|
int totalIterations = 5;
|
||||||
int numOfMetrics = 3;
|
int numOfMetrics = 3;
|
||||||
float[][] metrics = new float[numOfMetrics][totalIterations];
|
float[][] metrics = new float[numOfMetrics][totalIterations];
|
||||||
|
// Only assign metric values to the first dataset, zeros for other datasets
|
||||||
for (int i = 0; i < numOfMetrics; i++) {
|
for (int i = 0; i < numOfMetrics; i++) {
|
||||||
for (int j = 0; j < totalIterations; j++) {
|
for (int j = 0; j < totalIterations; j++) {
|
||||||
metrics[0][j] = j;
|
metrics[0][j] = j;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
int bestIteration;
|
||||||
|
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
bestIteration = i;
|
||||||
TestCase.assertTrue(onTrack);
|
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
|
||||||
|
TestCase.assertFalse(es);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// when we have multiple datasets, only the last one was used to determinate early stop
|
||||||
|
// Here we changed the metric of the first dataset, it doesn't have any effect to the final result
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = totalIterations - i;
|
metrics[0][i] = totalIterations - i;
|
||||||
}
|
}
|
||||||
// when we have multiple datasets, the training metrics is not considered
|
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
bestIteration = i;
|
||||||
TestCase.assertTrue(onTrack);
|
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
|
||||||
|
TestCase.assertFalse(es);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now assign metric values to the last dataset.
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[1][i] = totalIterations - i;
|
metrics[2][i] = totalIterations - i;
|
||||||
}
|
}
|
||||||
|
bestIteration = 0;
|
||||||
|
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
// if any metrics off, we need to stop
|
// if any metrics off, we need to stop
|
||||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
|
||||||
if (i >= earlyStoppingRound - 1) {
|
if (i >= earlyStoppingRound) {
|
||||||
TestCase.assertFalse(onTrack);
|
TestCase.assertTrue(es);
|
||||||
} else {
|
} else {
|
||||||
TestCase.assertTrue(onTrack);
|
TestCase.assertFalse(es);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDescendMetrics() {
|
public void testDescendMetrics() {
|
||||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
// maximize_evaluation_metrics = false
|
||||||
{
|
|
||||||
put("max_depth", 3);
|
|
||||||
put("silent", 1);
|
|
||||||
put("objective", "binary:logistic");
|
|
||||||
put("maximize_evaluation_metrics", "false");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
int totalIterations = 10;
|
int totalIterations = 10;
|
||||||
int earlyStoppingRounds = 5;
|
int earlyStoppingRounds = 5;
|
||||||
float[][] metrics = new float[1][totalIterations];
|
float[][] metrics = new float[1][totalIterations];
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = i;
|
metrics[0][i] = i;
|
||||||
}
|
}
|
||||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
int bestIteration = 0;
|
||||||
totalIterations - 1);
|
|
||||||
TestCase.assertFalse(onTrack);
|
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||||
|
TestCase.assertTrue(es);
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = totalIterations - i;
|
metrics[0][i] = totalIterations - i;
|
||||||
}
|
}
|
||||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
bestIteration = totalIterations - 1;
|
||||||
totalIterations - 1);
|
|
||||||
TestCase.assertTrue(onTrack);
|
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||||
|
TestCase.assertFalse(es);
|
||||||
|
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = totalIterations - i;
|
metrics[0][i] = totalIterations - i;
|
||||||
}
|
}
|
||||||
metrics[0][5] = 1;
|
metrics[0][4] = 1;
|
||||||
metrics[0][6] = 2;
|
metrics[0][9] = 5;
|
||||||
metrics[0][7] = 3;
|
|
||||||
metrics[0][8] = 4;
|
bestIteration = 4;
|
||||||
metrics[0][9] = 1;
|
|
||||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||||
totalIterations - 1);
|
TestCase.assertTrue(es);
|
||||||
TestCase.assertTrue(onTrack);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAscendMetricsWithBoundaryCondition() {
|
public void testAscendMetricsWithBoundaryCondition() {
|
||||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
// maximize_evaluation_metrics = true
|
||||||
{
|
int totalIterations = 11;
|
||||||
put("max_depth", 3);
|
|
||||||
put("silent", 1);
|
|
||||||
put("objective", "binary:logistic");
|
|
||||||
put("maximize_evaluation_metrics", "true");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
int totalIterations = 10;
|
|
||||||
int earlyStoppingRounds = 10;
|
int earlyStoppingRounds = 10;
|
||||||
float[][] metrics = new float[1][totalIterations];
|
float[][] metrics = new float[1][totalIterations];
|
||||||
for (int iter = 0; iter < totalIterations; iter++) {
|
|
||||||
if (iter == totalIterations - 1) {
|
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
|
||||||
metrics[0][i] = i;
|
|
||||||
}
|
|
||||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
|
||||||
TestCase.assertTrue(onTrack);
|
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = totalIterations - i;
|
metrics[0][i] = totalIterations - i;
|
||||||
}
|
}
|
||||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
int bestIteration = 0;
|
||||||
TestCase.assertFalse(onTrack);
|
|
||||||
|
for (int itr = 0; itr < totalIterations; itr++) {
|
||||||
|
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, itr, bestIteration);
|
||||||
|
if (itr == totalIterations - 1) {
|
||||||
|
TestCase.assertTrue(es);
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
TestCase.assertFalse(es);
|
||||||
metrics[0][i] = totalIterations - i;
|
|
||||||
}
|
|
||||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
|
||||||
TestCase.assertTrue(onTrack);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAscendMetrics() {
|
public void testAscendMetrics() {
|
||||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
// maximize_evaluation_metrics = true
|
||||||
{
|
|
||||||
put("max_depth", 3);
|
|
||||||
put("silent", 1);
|
|
||||||
put("objective", "binary:logistic");
|
|
||||||
put("maximize_evaluation_metrics", "true");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
int totalIterations = 10;
|
int totalIterations = 10;
|
||||||
int earlyStoppingRounds = 5;
|
int earlyStoppingRounds = 5;
|
||||||
float[][] metrics = new float[1][totalIterations];
|
float[][] metrics = new float[1][totalIterations];
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
|
||||||
metrics[0][i] = i;
|
|
||||||
}
|
|
||||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
|
||||||
TestCase.assertTrue(onTrack);
|
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = totalIterations - i;
|
metrics[0][i] = totalIterations - i;
|
||||||
}
|
}
|
||||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
int bestIteration = 0;
|
||||||
TestCase.assertFalse(onTrack);
|
|
||||||
|
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||||
|
TestCase.assertTrue(es);
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
metrics[0][i] = i;
|
metrics[0][i] = i;
|
||||||
}
|
}
|
||||||
metrics[0][5] = 9;
|
bestIteration = totalIterations - 1;
|
||||||
metrics[0][6] = 8;
|
|
||||||
metrics[0][7] = 7;
|
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||||
metrics[0][8] = 6;
|
TestCase.assertFalse(es);
|
||||||
metrics[0][9] = 9;
|
|
||||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
TestCase.assertTrue(onTrack);
|
metrics[0][i] = i;
|
||||||
|
}
|
||||||
|
metrics[0][4] = 9;
|
||||||
|
metrics[0][9] = 4;
|
||||||
|
|
||||||
|
bestIteration = 4;
|
||||||
|
|
||||||
|
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||||
|
TestCase.assertTrue(es);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -362,13 +333,13 @@ public class BoosterImplTest {
|
|||||||
|
|
||||||
// Make sure we've stopped early.
|
// Make sure we've stopped early.
|
||||||
for (int w = 0; w < watches.size(); w++) {
|
for (int w = 0; w < watches.size(); w++) {
|
||||||
for (int r = 0; r < earlyStoppingRound; r++) {
|
for (int r = 0; r <= earlyStoppingRound; r++) {
|
||||||
TestCase.assertFalse(0.0f == metrics[w][r]);
|
TestCase.assertFalse(0.0f == metrics[w][r]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int w = 0; w < watches.size(); w++) {
|
for (int w = 0; w < watches.size(); w++) {
|
||||||
for (int r = earlyStoppingRound; r < round; r++) {
|
for (int r = earlyStoppingRound + 1; r < round; r++) {
|
||||||
TestCase.assertEquals(0.0f, metrics[w][r]);
|
TestCase.assertEquals(0.0f, metrics[w][r]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user