diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 4a84f29af..b2d0624c2 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -242,11 +242,15 @@ public class XGBoost { if (score > bestScore) { bestScore = score; bestIteration = iter; + booster.setAttr("best_iteration", String.valueOf(bestIteration)); + booster.setAttr("best_score", String.valueOf(bestScore)); } } else { if (score < bestScore) { bestScore = score; bestIteration = iter; + booster.setAttr("best_iteration", String.valueOf(bestIteration)); + booster.setAttr("best_score", String.valueOf(bestScore)); } } if (earlyStoppingRounds > 0) { diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 0b0e4cb0d..700603b96 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -643,7 +643,7 @@ public class BoosterImplTest { }}); Map attr = booster.getAttrs(); - TestCase.assertEquals(attr.size(), 4); + TestCase.assertEquals(attr.size(), 6); TestCase.assertEquals(attr.get("testKey1"), "testValue2"); TestCase.assertEquals(attr.get("aa"), "AA"); TestCase.assertEquals(attr.get("bb"), "BB");