[jvm-packages] do not use multiple jobs to make checkpoints (#5082)

* temp

* temp

* temp

* address the comments

* fix stylistic issues

* fix

* external checkpoint
This commit is contained in:
Nan Zhu
2020-02-01 19:36:39 -08:00
committed by GitHub
parent fa26313feb
commit d7b45fbcaf
14 changed files with 464 additions and 320 deletions

View File

@@ -0,0 +1,37 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
import org.apache.hadoop.fs.FileSystem
/**
 * Scala-facing wrapper around the Java `ExternalCheckpointManager` that accepts and
 * returns the Scala [[Booster]] type instead of the underlying Java booster.
 *
 * @param checkpointPath checkpoint location, forwarded unchanged to the Java manager
 * @param fs Hadoop file system handle, forwarded unchanged to the Java manager
 */
class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
  extends JavaECM(checkpointPath, fs) {

  /**
   * Forwards the Java booster wrapped by `booster` to the Java manager's
   * `updateCheckpoint`.
   *
   * @param booster Scala booster whose underlying Java booster is handed over
   */
  def updateCheckpoint(booster: Booster): Unit =
    super.updateCheckpoint(booster.booster)

  /**
   * Asks the Java manager for its checkpointed booster and wraps it as a Scala [[Booster]].
   *
   * @return a new Scala booster around the loaded Java booster, or `null` when the Java
   *         manager returned `null`
   */
  def loadCheckpointAsScalaBooster(): Booster =
    // Option(...) maps a null Java result to None; orNull restores the null for callers.
    Option(super.loadCheckpointAsBooster()).map(new Booster(_)).orNull
}

View File

@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala
import java.io.InputStream
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
/**
* XGBoost Scala Training function.
*/
object XGBoost {
/**
 * Trains through the Java API, optionally with external checkpointing.
 *
 * When `checkpointParams` is defined, training is delegated to
 * `JXGBoost.trainAndSaveCheckpoint` with the configured interval, the checkpoint path and
 * the Hadoop file system resolved from that path using a fresh `Configuration`.
 * Otherwise plain `JXGBoost.train` is used.
 *
 * @param dtrain training data
 * @param params training parameters; entries with a `null` value are dropped and the
 *               remaining values are passed to Java as strings
 * @param numRounds number of rounds handed to the Java trainer
 * @param watches named matrices handed to the Java trainer as a Java map
 * @param metrics passed through to the Java trainer unchanged
 * @param obj customized objective, passed through unchanged (may be `null`)
 * @param eval customized evaluation, passed through unchanged (may be `null`)
 * @param earlyStoppingRound passed through to the Java trainer unchanged
 * @param prevBooster booster to continue from, or `null` to start without one
 * @param checkpointParams external checkpoint settings, or `None` to train without them
 * @return `prevBooster` itself when it is non-null, otherwise a new [[Booster]] wrapping
 *         the Java booster returned by the trainer
 */
private[scala] def trainAndSaveCheckpoint(
    dtrain: DMatrix,
    params: Map[String, Any],
    numRounds: Int,
    watches: Map[String, DMatrix] = Map(),
    metrics: Array[Array[Float]] = null,
    obj: ObjectiveTrait = null,
    eval: EvalTrait = null,
    earlyStoppingRound: Int = 0,
    prevBooster: Booster,
    checkpointParams: Option[ExternalCheckpointParams]): Booster = {
  val javaWatches = watches.mapValues(_.jDMatrix).asJava
  val javaPrevBooster = if (prevBooster != null) prevBooster.booster else null
  // we have to filter null value for customized obj and eval
  val javaParams =
    params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava

  val trainedJavaBooster = checkpointParams match {
    case Some(cp) =>
      JXGBoost.trainAndSaveCheckpoint(
        dtrain.jDMatrix,
        javaParams,
        numRounds, javaWatches, metrics, obj, eval, earlyStoppingRound, javaPrevBooster,
        cp.checkpointInterval,
        cp.checkpointPath,
        new Path(cp.checkpointPath).getFileSystem(new Configuration()))
    case None =>
      JXGBoost.train(
        dtrain.jDMatrix,
        javaParams,
        numRounds, javaWatches, metrics, obj, eval, earlyStoppingRound, javaPrevBooster)
  }

  // Avoid creating a new SBooster with the same JBooster
  if (prevBooster != null) prevBooster else new Booster(trainedJavaBooster)
}
/**
* Train a booster given parameters.
*
@@ -55,23 +101,8 @@ object XGBoost {
eval: EvalTrait = null,
earlyStoppingRound: Int = 0,
booster: Booster = null): Booster = {
val jWatches = watches.mapValues(_.jDMatrix).asJava
val jBooster = if (booster == null) {
null
} else {
booster.booster
}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
if (booster == null) {
new Booster(xgboostInJava)
} else {
// Avoid creating a new SBooster with the same JBooster
booster
}
trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
booster, None)
}
/**
@@ -126,3 +157,41 @@ object XGBoost {
new Booster(xgboostInJava)
}
}
/**
 * External checkpoint settings, as built by `ExternalCheckpointParams.extractParams`
 * from the training parameter map.
 *
 * @param checkpointInterval value of the "checkpoint_interval" parameter
 *                           (presumably rounds between checkpoints; confirm against
 *                           the Java `trainAndSaveCheckpoint`)
 * @param checkpointPath value of the "checkpoint_path" parameter; also used to resolve
 *                       the Hadoop file system when training with checkpoints
 * @param skipCleanCheckpoint value of the "skip_clean_checkpoint" parameter
 */
private[scala] case class ExternalCheckpointParams(
checkpointInterval: Int,
checkpointPath: String,
skipCleanCheckpoint: Boolean)
private[scala] object ExternalCheckpointParams {

  /**
   * Reads the external checkpoint settings out of a training parameter map.
   *
   * Keys consulted: "checkpoint_path" (String), "checkpoint_interval" (Int) and
   * "skip_clean_checkpoint" (Boolean). A missing, `null` or empty path and a missing
   * interval (treated as 0) both mean checkpointing is not configured.
   *
   * @param params training parameters
   * @return `Some` settings when a non-empty path and a non-zero interval are present,
   *         otherwise `None`
   * @throws IllegalArgumentException if any of the three keys holds a value of the
   *                                  wrong type
   */
  def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
    val maybePath: Option[String] = params.get("checkpoint_path") match {
      case None | Some(null) | Some("") => None
      case Some(path: String) => Some(path)
      case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
        s" an instance of String, but current value is ${params("checkpoint_path")}")
    }

    val interval: Int = params.get("checkpoint_interval") match {
      case Some(freq: Int) => freq
      case None => 0
      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
        " an instance of Int.")
    }

    val skipClean: Boolean = params.get("skip_clean_checkpoint") match {
      case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
      case None => false
      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
        " an instance of Boolean")
    }

    // All three keys are validated above before deciding whether checkpointing is enabled.
    for {
      path <- maybePath
      if interval != 0
    } yield ExternalCheckpointParams(interval, path, skipClean)
  }
}