[jvm-packages] do not use multiple jobs to make checkpoints (#5082)
* temp * temp * tep * address the comments * fix stylistic issues * fix * external checkpoint
This commit is contained in:
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
|
||||
import org.apache.hadoop.fs.FileSystem
|
||||
|
||||
class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
|
||||
extends JavaECM(checkpointPath, fs) {
|
||||
|
||||
def updateCheckpoint(booster: Booster): Unit = {
|
||||
super.updateCheckpoint(booster.booster)
|
||||
}
|
||||
|
||||
def loadCheckpointAsScalaBooster(): Booster = {
|
||||
val loadedBooster = super.loadCheckpointAsBooster()
|
||||
if (loadedBooster == null) {
|
||||
null
|
||||
} else {
|
||||
new Booster(loadedBooster)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.InputStream
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
/**
|
||||
* XGBoost Scala Training function.
|
||||
*/
|
||||
object XGBoost {
|
||||
|
||||
private[scala] def trainAndSaveCheckpoint(
|
||||
dtrain: DMatrix,
|
||||
params: Map[String, Any],
|
||||
numRounds: Int,
|
||||
watches: Map[String, DMatrix] = Map(),
|
||||
metrics: Array[Array[Float]] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
earlyStoppingRound: Int = 0,
|
||||
prevBooster: Booster,
|
||||
checkpointParams: Option[ExternalCheckpointParams]): Booster = {
|
||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||
val jBooster = if (prevBooster == null) {
|
||||
null
|
||||
} else {
|
||||
prevBooster.booster
|
||||
}
|
||||
val xgboostInJava = checkpointParams.
|
||||
map(cp => {
|
||||
JXGBoost.trainAndSaveCheckpoint(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
|
||||
cp.checkpointInterval,
|
||||
cp.checkpointPath,
|
||||
new Path(cp.checkpointPath).getFileSystem(new Configuration()))
|
||||
}).
|
||||
getOrElse(
|
||||
JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||
)
|
||||
if (prevBooster == null) {
|
||||
new Booster(xgboostInJava)
|
||||
} else {
|
||||
// Avoid creating a new SBooster with the same JBooster
|
||||
prevBooster
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
@@ -55,23 +101,8 @@ object XGBoost {
|
||||
eval: EvalTrait = null,
|
||||
earlyStoppingRound: Int = 0,
|
||||
booster: Booster = null): Booster = {
|
||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||
val jBooster = if (booster == null) {
|
||||
null
|
||||
} else {
|
||||
booster.booster
|
||||
}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||
if (booster == null) {
|
||||
new Booster(xgboostInJava)
|
||||
} else {
|
||||
// Avoid creating a new SBooster with the same JBooster
|
||||
booster
|
||||
}
|
||||
trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
|
||||
booster, None)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -126,3 +157,41 @@ object XGBoost {
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
}
|
||||
|
||||
private[scala] case class ExternalCheckpointParams(
|
||||
checkpointInterval: Int,
|
||||
checkpointPath: String,
|
||||
skipCleanCheckpoint: Boolean)
|
||||
|
||||
private[scala] object ExternalCheckpointParams {
|
||||
|
||||
def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
|
||||
val checkpointPath: String = params.get("checkpoint_path") match {
|
||||
case None | Some(null) | Some("") => null
|
||||
case Some(path: String) => path
|
||||
case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
|
||||
s" an instance of String, but current value is ${params("checkpoint_path")}")
|
||||
}
|
||||
|
||||
val checkpointInterval: Int = params.get("checkpoint_interval") match {
|
||||
case None => 0
|
||||
case Some(freq: Int) => freq
|
||||
case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
|
||||
" an instance of Int.")
|
||||
}
|
||||
|
||||
val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
|
||||
case None => false
|
||||
case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
|
||||
case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
|
||||
" an instance of Boolean")
|
||||
}
|
||||
if (checkpointPath == null || checkpointInterval == 0) {
|
||||
None
|
||||
} else {
|
||||
Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user