[jvm-packages] Prevent dispose being called on unfinalized JBooster (#3005)
* [jvm-packages] Prevent dispose being called twice when finalize * Convert SIGSEGV to XGBoostError * Avoid creating a new SBooster with the same JBooster * Address CR Comments
This commit is contained in:
parent
9747ea2acb
commit
65fb4e3f5c
@ -16,8 +16,6 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
import com.esotericsoftware.kryo.io.{Output, Input}
|
||||
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
|
||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
|
||||
@ -204,11 +202,6 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
booster.dispose()
|
||||
}
|
||||
|
||||
override def finalize(): Unit = {
|
||||
super.finalize()
|
||||
dispose
|
||||
}
|
||||
|
||||
override def write(kryo: Kryo, output: Output): Unit = {
|
||||
kryo.writeObject(output, booster)
|
||||
}
|
||||
|
||||
@ -66,7 +66,12 @@ object XGBoost {
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||
new Booster(xgboostInJava)
|
||||
if (booster == null) {
|
||||
new Booster(xgboostInJava)
|
||||
} else {
|
||||
// Avoid creating a new SBooster with the same JBooster
|
||||
booster
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -198,4 +198,16 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
trainBoosterWithFastHisto(trainMat, Map("training" -> trainMat),
|
||||
round = 10, paramMap, 0.85f)
|
||||
}
|
||||
|
||||
test("test training from existing model in scala") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val paramMap = List("max_depth" -> "0", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
|
||||
"eval_metric" -> "auc").toMap
|
||||
|
||||
val prevBooster = XGBoost.train(trainMat, paramMap, round = 2)
|
||||
val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster)
|
||||
assert(prevBooster == nextBooster)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user