From 65fb4e3f5ca52498b7244bd59f0a438a42dc67b5 Mon Sep 17 00:00:00 2001
From: Yun Ni
Date: Sat, 6 Jan 2018 09:46:52 -0800
Subject: [PATCH] [jvm-packages] Prevent dispose being called on unfinalized JBooster (#3005)

* [jvm-packages] Prevent dispose being called twice when finalize
* Convert SIGSEGV to XGBoostError
* Avoid creating a new SBooster with the same JBooster
* Address CR Comments
---
 .../main/scala/ml/dmlc/xgboost4j/scala/Booster.scala |  7 -------
 .../main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala |  7 ++++++-
 .../dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala | 12 ++++++++++++
 3 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
index 4d0c839f2..a82294974 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala
@@ -16,8 +16,6 @@
 
 package ml.dmlc.xgboost4j.scala
 
-import java.io.IOException
-
 import com.esotericsoftware.kryo.io.{Output, Input}
 import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
 import ml.dmlc.xgboost4j.java.{Booster => JBooster}
@@ -204,11 +202,6 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
     booster.dispose()
   }
 
-  override def finalize(): Unit = {
-    super.finalize()
-    dispose
-  }
-
   override def write(kryo: Kryo, output: Output): Unit = {
     kryo.writeObject(output, booster)
   }
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
index 76c04921a..609d7b2cd 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
@@ -66,7 +66,12 @@ object XGBoost {
       // we have to filter null value for customized obj and eval
       params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
       round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
-    new Booster(xgboostInJava)
+    if (booster == null) {
+      new Booster(xgboostInJava)
+    } else {
+      // Avoid creating a new SBooster with the same JBooster
+      booster
+    }
   }
 
   /**
diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
index 2c3ce62a7..1791c4240 100644
--- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
+++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
@@ -198,4 +198,16 @@ class ScalaBoosterImplSuite extends FunSuite {
     trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat),
       round = 10, paramMap, 0.85f)
   }
+
+  test("test training from existing model in scala") {
+    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+    val paramMap = List("max_depth" -> "0", "silent" -> "0",
+      "objective" -> "binary:logistic", "tree_method" -> "hist",
+      "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
+      "eval_metric" -> "auc").toMap
+
+    val prevBooster = XGBoost.train(trainMat, paramMap, round = 2)
+    val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster)
+    assert(prevBooster == nextBooster)
+  }
 }