[jvm-packages] do not use multiple jobs to make checkpoints (#5082)

* temp

* temp

* temp

* address the comments

* fix stylistic issues

* fix

* external checkpoint
This commit is contained in:
Nan Zhu
2020-02-01 19:36:39 -08:00
committed by GitHub
parent fa26313feb
commit d7b45fbcaf
14 changed files with 464 additions and 320 deletions

View File

@@ -1,164 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
/**
* A class which allows the user to save checkpoints every few rounds. If a previous job fails,
* the job can restart training from a saved checkpoint instead of from scratch. This class
* provides the interface and helper methods for the checkpoint functionality.
*
* NOTE: This checkpoint is different from a Rabit checkpoint. A Rabit checkpoint is a native-level
* checkpoint stored in executor memory. This is a checkpoint which the Spark driver stores on HDFS
* every few iterations.
*
* @param sc the sparkContext object
* @param checkpointPath the hdfs path to store checkpoints
*/
private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {

  private val logger = LogFactory.getLog("XGBoostSpark")

  // Checkpoint files are named "<version>.model". The version is an internal counter;
  // elsewhere in this class it is divided by 2 to recover the boosting round, so
  // version == 2 * round by convention.
  private val modelSuffix = ".model"

  /** Full path of the checkpoint file for the given internal version. */
  private def getPath(version: Int) = {
    s"$checkpointPath/$version$modelSuffix"
  }

  /**
   * List the internal versions of all checkpoints currently stored under
   * `checkpointPath`. Returns an empty seq when the path is unset or does not exist.
   */
  private def getExistingVersions: Seq[Int] = {
    val fs = FileSystem.get(sc.hadoopConfiguration)
    if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
      Seq()
    } else {
      fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
        // Only accept "<digits>.model". Previously any "*.model" basename was fed to
        // .toInt, so a stray non-numeric file crashed the listing with
        // NumberFormatException; such files are now simply ignored.
        case fileName if fileName.endsWith(modelSuffix) &&
            fileName.stripSuffix(modelSuffix).nonEmpty &&
            fileName.stripSuffix(modelSuffix).forall(_.isDigit) =>
          fileName.stripSuffix(modelSuffix).toInt
      }
    }
  }

  /** Recursively delete the whole checkpoint directory; no-op when the path is unset. */
  def cleanPath(): Unit = {
    if (checkpointPath != "") {
      FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
    }
  }

  /**
   * Load the existing checkpoint with the highest version as a Booster object.
   *
   * @return the booster with the highest version, null if no checkpoints are available.
   */
  private[spark] def loadCheckpointAsBooster: Booster = {
    val versions = getExistingVersions
    if (versions.nonEmpty) {
      val version = versions.max
      val fullPath = getPath(version)
      val inputStream = FileSystem.get(sc.hadoopConfiguration).open(new Path(fullPath))
      try {
        logger.info(s"Start training from previous booster at $fullPath")
        val booster = SXGBoost.loadModel(inputStream)
        booster.booster.setVersion(version)
        booster
      } finally {
        // Always release the HDFS stream, even if model loading fails
        // (it was previously leaked on every call).
        inputStream.close()
      }
    } else {
      null
    }
  }

  /**
   * Save a new checkpoint, then clean up all previous checkpoints so that at most
   * one checkpoint file remains in the directory.
   *
   * @param checkpoint the booster to save; its internal version names the file
   */
  private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
    val fs = FileSystem.get(sc.hadoopConfiguration)
    // Snapshot the old versions before writing, so the new file is not deleted below.
    val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
    val fullPath = getPath(checkpoint.getVersion)
    val outputStream = fs.create(new Path(fullPath), true)
    logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
    try {
      checkpoint.saveModel(outputStream)
    } finally {
      // Ensure the HDFS stream is flushed and released even if saving throws.
      outputStream.close()
    }
    prevModelPaths.foreach(path => fs.delete(path, true))
  }

  /**
   * Clean up checkpoint boosters with version higher than or equal to the round.
   * (version / 2 recovers the round the checkpoint was taken at.)
   *
   * @param round the number of rounds in the current training job
   */
  private[spark] def cleanUpHigherVersions(round: Int): Unit = {
    // FileSystem.get is loop-invariant; fetch it once instead of per deleted file.
    val fs = FileSystem.get(sc.hadoopConfiguration)
    getExistingVersions.filter(_ / 2 >= round).foreach { version =>
      fs.delete(new Path(getPath(version)), true)
    }
  }

  /**
   * Calculate a list of checkpoint rounds to save checkpoints based on the checkpointInterval
   * and total number of rounds for the training. Concretely, the checkpoint rounds start with
   * prevRounds + checkpointInterval, and increase by checkpointInterval in each step until it
   * reaches total number of rounds. If checkpointInterval is 0, the checkpoint will be disabled
   * and the method returns Seq(round)
   *
   * @param checkpointInterval Period (in iterations) between checkpoints.
   * @param round the total number of rounds for the training
   * @return a seq of integers, each represent the index of round to save the checkpoints
   */
  private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
    if (checkpointPath.nonEmpty && checkpointInterval > 0) {
      val prevRounds = getExistingVersions.map(_ / 2)
      val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
      // `until` excludes `round` itself, so appending `round` never duplicates it.
      (firstCheckpointRound until round by checkpointInterval) :+ round
    } else if (checkpointInterval <= 0) {
      Seq(round)
    } else {
      // interval > 0 but no path configured: refuse rather than silently skip checkpoints.
      throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
    }
  }
}
object CheckpointManager {

  /**
   * Checkpoint-related settings extracted from the raw training parameters.
   *
   * @param checkpointPath directory for checkpoint files; "" disables checkpointing
   * @param checkpointInterval period (in rounds) between checkpoints; 0 disables checkpointing
   * @param skipCleanCheckpoint whether to keep the checkpoint files after a successful training
   */
  case class CheckpointParam(
      checkpointPath: String,
      checkpointInterval: Int,
      skipCleanCheckpoint: Boolean)

  /**
   * Extract the three checkpoint settings from `params`, falling back to
   * "" / 0 / false when a key is absent.
   *
   * @throws IllegalArgumentException when a provided value has an unexpected type
   */
  private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
    val path = params.get("checkpoint_path") match {
      case Some(p: String) => p
      case None => ""
      case _ =>
        throw new IllegalArgumentException(
          "parameter \"checkpoint_path\" must be an instance of String.")
    }
    val interval = params.get("checkpoint_interval") match {
      case Some(i: Int) => i
      case None => 0
      case _ =>
        throw new IllegalArgumentException(
          "parameter \"checkpoint_interval\" must be an instance of Int.")
    }
    val skipClean = params.get("skip_clean_checkpoint") match {
      case Some(skip: Boolean) => skip
      case None => false
      case _ =>
        throw new IllegalArgumentException(
          "parameter \"skip_clean_checkpoint\" must be an instance of Boolean")
    }
    CheckpointParam(path, interval, skipClean)
  }
}

View File

@@ -25,12 +25,13 @@ import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
@@ -64,7 +65,7 @@ private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, see
private[this] case class XGBoostExecutionParams(
numWorkers: Int,
round: Int,
numRounds: Int,
useExternalMemory: Boolean,
obj: ObjectiveTrait,
eval: EvalTrait,
@@ -72,7 +73,7 @@ private[this] case class XGBoostExecutionParams(
allowNonZeroForMissing: Boolean,
trackerConf: TrackerConf,
timeoutRequestWorkers: Long,
checkpointParam: CheckpointParam,
checkpointParam: Option[ExternalCheckpointParams],
xgbInputParams: XGBoostExecutionInputParams,
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean) {
@@ -167,7 +168,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
.getOrElse("allow_non_zero_for_missing", false)
.asInstanceOf[Boolean]
validateSparkSslConf
if (overridedParams.contains("tree_method")) {
require(overridedParams("tree_method") == "hist" ||
overridedParams("tree_method") == "approx" ||
@@ -198,7 +198,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
" an instance of Long.")
}
val checkpointParam =
CheckpointManager.extractParams(overridedParams)
ExternalCheckpointParams.extractParams(overridedParams)
val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
.asInstanceOf[Double]
@@ -339,11 +339,9 @@ object XGBoost extends Serializable {
watches: Watches,
xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
round: Int,
obj: ObjectiveTrait,
eval: EvalTrait,
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
// to workaround the empty partitions in training dataset,
// this might not be the best efficient implementation, see
// (https://github.com/dmlc/xgboost/issues/1277)
@@ -357,14 +355,23 @@ object XGBoost extends Serializable {
rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
val numRounds = xgbExecutionParam.numRounds
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
try {
Rabit.init(rabitEnv)
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
val externalCheckpointParams = xgbExecutionParam.checkpointParam
val booster = if (makeCheckpoint) {
SXGBoost.trainAndSaveCheckpoint(
watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
} else {
SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
}
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
} catch {
case xgbException: XGBoostError =>
@@ -437,7 +444,6 @@ object XGBoost extends Serializable {
trainingData: RDD[XGBLabeledPoint],
xgbExecutionParams: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
checkpointRound: Int,
prevBooster: Booster,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
if (evalSetsMap.isEmpty) {
@@ -446,8 +452,8 @@ object XGBoost extends Serializable {
processMissingValues(labeledPoints, xgbExecutionParams.missing,
xgbExecutionParams.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParams.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
xgbExecutionParams.eval, prevBooster)
}).cache()
} else {
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
@@ -459,8 +465,8 @@ object XGBoost extends Serializable {
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParams.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
xgbExecutionParams.eval, prevBooster)
}.cache()
}
}
@@ -469,7 +475,6 @@ object XGBoost extends Serializable {
trainingData: RDD[Array[XGBLabeledPoint]],
xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
checkpointRound: Int,
prevBooster: Booster,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
if (evalSetsMap.isEmpty) {
@@ -478,7 +483,7 @@ object XGBoost extends Serializable {
processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
xgbExecutionParam.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParam.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
}).cache()
} else {
@@ -490,7 +495,7 @@ object XGBoost extends Serializable {
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParam.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
xgbExecutionParam.obj,
xgbExecutionParam.eval,
prevBooster)
@@ -529,60 +534,58 @@ object XGBoost extends Serializable {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
val sc = trainingData.sparkContext
val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
checkpointPath)
checkpointManager.cleanUpHigherVersions(xgbExecParams.round)
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
hasGroup, xgbExecParams.numWorkers)
var prevBooster = checkpointManager.loadCheckpointAsBooster
val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
val checkpointManager = new ExternalCheckpointManager(
checkpointParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
checkpointManager.loadCheckpointAsScalaBooster()
}.orNull
try {
// Train for every ${savingRound} rounds and save the partially completed booster
val producedBooster = checkpointManager.getCheckpointRounds(
xgbExecParams.checkpointParam.checkpointInterval,
xgbExecParams.round).map {
checkpointRound: Int =>
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
try {
val parallelismTracker = new SparkParallelismTracker(sc,
xgbExecParams.timeoutRequestWorkers,
xgbExecParams.numWorkers)
tracker.getWorkerEnvs().putAll(xgbRabitParams)
val boostersAndMetrics = if (hasGroup) {
trainForRanking(transformedTrainingData.left.get, xgbExecParams,
tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
} else {
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams,
tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
}
val sparkJobThread = new Thread() {
override def run() {
// force the job
boostersAndMetrics.foreachPartition(() => _)
}
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
logger.info(s"Rabit returns with exit code $trackerReturnVal")
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
boostersAndMetrics, sparkJobThread)
if (checkpointRound < xgbExecParams.round) {
prevBooster = booster
checkpointManager.updateCheckpoint(prevBooster)
}
(booster, metrics)
} finally {
tracker.stop()
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
val (booster, metrics) = try {
val parallelismTracker = new SparkParallelismTracker(sc,
xgbExecParams.timeoutRequestWorkers,
xgbExecParams.numWorkers)
val rabitEnv = tracker.getWorkerEnvs
val boostersAndMetrics = if (hasGroup) {
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
evalSetsMap)
} else {
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
prevBooster, evalSetsMap)
}
val sparkJobThread = new Thread() {
override def run() {
// force the job
boostersAndMetrics.foreachPartition(() => _)
}
}.last
// we should delete the checkpoint directory after a successful training
if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) {
checkpointManager.cleanPath()
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
logger.info(s"Rabit returns with exit code $trackerReturnVal")
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
boostersAndMetrics, sparkJobThread)
(booster, metrics)
} finally {
tracker.stop()
}
producedBooster
// we should delete the checkpoint directory after a successful training
xgbExecParams.checkpointParam.foreach {
cpParam =>
if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
val checkpointManager = new ExternalCheckpointManager(
cpParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
checkpointManager.cleanPath()
}
}
(booster, metrics)
} catch {
case t: Throwable =>
// if the job was aborted due to an exception

View File

@@ -24,7 +24,7 @@ private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
def needDeterministicRepartitioning: Boolean = {
getCheckpointPath.nonEmpty && getCheckpointInterval > 0
getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
}
}

View File

@@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import org.scalatest.FunSuite
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
import org.scalatest.{FunSuite, Ignore}
import org.apache.hadoop.fs.{FileSystem, Path}
class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
private lazy val (model4, model8) = {
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
Map[String, Any] = {
Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
"checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
}
private def createNewModels():
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val (model4, model8) = {
val training = buildDataFrame(Classification.train)
val paramMap = produceParamMap(tmpPath, 2)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
}
(tmpPath, model4, model8)
}
test("test update/load models") {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateCheckpoint(model4._booster)
val (tmpPath, model4, model8) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model4._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
manager.updateCheckpoint(model8._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
}
test("test cleanUpHigherVersions") {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
val (tmpPath, model4, model8) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model8._booster)
manager.cleanUpHigherVersions(round = 8)
manager.cleanUpHigherVersions(8)
assert(new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(round = 4)
manager.cleanUpHigherVersions(4)
assert(!new File(s"$tmpPath/8.model").exists())
}
test("test checkpoint rounds") {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
import scala.collection.JavaConverters._
val (tmpPath, model4, model8) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(7))(
manager.getCheckpointRounds(0, 7).asScala)
assertResult(Seq(2, 4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
assertResult(Seq(4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
}
@@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
val testDM = new DMatrix(Classification.test.iterator)
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
val paramMap = produceParamMap(tmpPath, 2)
val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
val skipCleanCheckpointMap =
if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
skipCleanCheckpointMap
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
def error(model: Booster): Float = eval.eval(
model.predict(testDM, outPutMargin = true), testDM)
val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)
if (skipCleanCheckpoint) {
// Check only one model is kept after training
@@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) > error(prevModel._booster))
assert(error(tmpModel) >= error(prevModel._booster))
assert(error(prevModel._booster) > error(nextModel._booster))
assert(error(nextModel._booster) < 0.1)
} else {

View File

@@ -127,7 +127,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
" stop the application") {
val spark = ss
import spark.implicits._
ss.sparkContext.setLogLevel("INFO")
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
// vector,
val testDF = Seq(
@@ -155,7 +154,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
"does not stop application") {
val spark = ss
import spark.implicits._
ss.sparkContext.setLogLevel("INFO")
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
// vector,
val testDF = Seq(

View File

@@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.XGBoostError
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
import org.apache.spark.ml.param.ParamMap

View File

@@ -20,14 +20,12 @@ import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.FunSuite
import org.scalatest.{FunSuite, Ignore}
class RabitRobustnessSuite extends FunSuite with PerTest {