[XGBoost4J-Spark] bestIteration and bestScore for early stopping (#7095)
This commit is contained in:
parent
d7c14496d2
commit
9f7f8b976d
@ -242,11 +242,15 @@ public class XGBoost {
|
||||
if (score > bestScore) {
|
||||
bestScore = score;
|
||||
bestIteration = iter;
|
||||
booster.setAttr("best_iteration", String.valueOf(bestIteration));
|
||||
booster.setAttr("best_score", String.valueOf(bestScore));
|
||||
}
|
||||
} else {
|
||||
if (score < bestScore) {
|
||||
bestScore = score;
|
||||
bestIteration = iter;
|
||||
booster.setAttr("best_iteration", String.valueOf(bestIteration));
|
||||
booster.setAttr("best_score", String.valueOf(bestScore));
|
||||
}
|
||||
}
|
||||
if (earlyStoppingRounds > 0) {
|
||||
|
||||
@ -643,7 +643,7 @@ public class BoosterImplTest {
|
||||
}});
|
||||
|
||||
Map<String, String> attr = booster.getAttrs();
|
||||
TestCase.assertEquals(attr.size(), 4);
|
||||
TestCase.assertEquals(attr.size(), 6);
|
||||
TestCase.assertEquals(attr.get("testKey1"), "testValue2");
|
||||
TestCase.assertEquals(attr.get("aa"), "AA");
|
||||
TestCase.assertEquals(attr.get("bb"), "BB");
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user