[jvm-packages] Saving models into a tmp folder every few rounds (#2964)

* [jvm-packages] Train Booster from an existing model

* Align Scala API with Java API

* Existing model should not load rabit checkpoint

* Address minor comments

* Implement saving temporary boosters and loading previous booster

* Add more unit tests for loadPrevBooster

* Add params to XGBoostEstimator

* (1) Move repartition out of the temp model saving loop (2) Address CR comments

* Catch a corner case of training next model with fewer rounds

* Address comments

* Refactor newly added methods into TmpBoosterManager

* Add two files which were missing in the previous commit

* Rename TmpBooster to checkpoint
This commit is contained in:
Yun Ni
2017-12-29 08:36:41 -08:00
committed by Nan Zhu
parent eedca8c8ec
commit 9004ca03ca
11 changed files with 481 additions and 60 deletions

View File

@@ -25,7 +25,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError
import scala.collection.JavaConverters._
import scala.collection.mutable
class Booster private[xgboost4j](private var booster: JBooster)
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
extends Serializable with KryoSerializable {
/**

View File

@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
import java.io.InputStream
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError}
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
import scala.collection.JavaConverters._
/**
@@ -41,6 +41,7 @@ object XGBoost {
* increases in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @param booster train from scratch if set to null; train from an existing booster if not null.
* @return The trained booster.
*/
@throws(classOf[XGBoostError])
@@ -52,13 +53,19 @@ object XGBoost {
metrics: Array[Array[Float]] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
earlyStoppingRound: Int = 0): Booster = {
earlyStoppingRound: Int = 0,
booster: Booster = null): Booster = {
val jWatches = watches.mapValues(_.jDMatrix).asJava
val jBooster = if (booster == null) {
null
} else {
booster.booster
}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
round, jWatches, metrics, obj, eval, earlyStoppingRound)
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
new Booster(xgboostInJava)
}