From 9f7f8b976decb761087719249e6c6b140ad1dac2 Mon Sep 17 00:00:00 2001 From: naveenkb Date: Mon, 19 Jul 2021 16:16:49 +0530 Subject: [PATCH] [XGBoost4J-Spark] bestIteration and bestScore for early stopping (#7095) --- .../src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java | 4 ++++ .../src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 4a84f29af..b2d0624c2 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -242,11 +242,15 @@ public class XGBoost { if (score > bestScore) { bestScore = score; bestIteration = iter; + booster.setAttr("best_iteration", String.valueOf(bestIteration)); + booster.setAttr("best_score", String.valueOf(bestScore)); } } else { if (score < bestScore) { bestScore = score; bestIteration = iter; + booster.setAttr("best_iteration", String.valueOf(bestIteration)); + booster.setAttr("best_score", String.valueOf(bestScore)); } } if (earlyStoppingRounds > 0) { diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 0b0e4cb0d..700603b96 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -643,7 +643,7 @@ public class BoosterImplTest { }}); Map attr = booster.getAttrs(); - TestCase.assertEquals(attr.size(), 4); + TestCase.assertEquals(attr.size(), 6); TestCase.assertEquals(attr.get("testKey1"), "testValue2"); TestCase.assertEquals(attr.get("aa"), "AA"); TestCase.assertEquals(attr.get("bb"), "BB");