[jvm-packages] Saving models into a tmp folder every a few rounds (#2964)
* [jvm-packages] Train Booster from an existing model * Align Scala API with Java API * Existing model should not load rabit checkpoint * Address minor comments * Implement saving temporary boosters and loading previous booster * Add more unit tests for loadPrevBooster * Add params to XGBoostEstimator * (1) Move repartition out of the temp model saving loop (2) Address CR comments * Catch a corner case of training next model with fewer rounds * Address comments * Refactor newly added methods into TmpBoosterManager * Add two files which is missing in previous commit * Rename TmpBooster to checkpoint
This commit is contained in:
@@ -25,7 +25,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
class Booster private[xgboost4j](private var booster: JBooster)
|
||||
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
extends Serializable with KryoSerializable {
|
||||
|
||||
/**
|
||||
|
||||
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.InputStream
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
/**
|
||||
@@ -41,6 +41,7 @@ object XGBoost {
|
||||
* increases in any evaluation metric.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||
* @return The trained booster.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
@@ -52,13 +53,19 @@ object XGBoost {
|
||||
metrics: Array[Array[Float]] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
earlyStoppingRound: Int = 0): Booster = {
|
||||
earlyStoppingRound: Int = 0,
|
||||
booster: Booster = null): Booster = {
|
||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||
val jBooster = if (booster == null) {
|
||||
null
|
||||
} else {
|
||||
booster.booster
|
||||
}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
round, jWatches, metrics, obj, eval, earlyStoppingRound)
|
||||
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user