[XGBoost4J-Spark] bestIteration and bestScore for early stopping (#7095)

This commit is contained in:
naveenkb 2021-07-19 16:16:49 +05:30 committed by GitHub
parent d7c14496d2
commit 9f7f8b976d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 1 deletions

View File

@ -242,11 +242,15 @@ public class XGBoost {
if (score > bestScore) {
bestScore = score;
bestIteration = iter;
booster.setAttr("best_iteration", String.valueOf(bestIteration));
booster.setAttr("best_score", String.valueOf(bestScore));
}
} else {
if (score < bestScore) {
bestScore = score;
bestIteration = iter;
booster.setAttr("best_iteration", String.valueOf(bestIteration));
booster.setAttr("best_score", String.valueOf(bestScore));
}
}
if (earlyStoppingRounds > 0) {

View File

@ -643,7 +643,7 @@ public class BoosterImplTest {
}});
Map<String, String> attr = booster.getAttrs();
TestCase.assertEquals(attr.size(), 4);
TestCase.assertEquals(attr.size(), 6);
TestCase.assertEquals(attr.get("testKey1"), "testValue2");
TestCase.assertEquals(attr.get("aa"), "AA");
TestCase.assertEquals(attr.get("bb"), "BB");