[jvm-packages] do not use multiple jobs to make checkpoints (#5082)

* temp

* temp

* tep

* address the comments

* fix stylistic issues

* fix

* external checkpoint
This commit is contained in:
Nan Zhu
2020-02-01 19:36:39 -08:00
committed by GitHub
parent fa26313feb
commit d7b45fbcaf
14 changed files with 464 additions and 320 deletions

View File

@@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import org.scalatest.FunSuite
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
import org.scalatest.{FunSuite, Ignore}
import org.apache.hadoop.fs.{FileSystem, Path}
class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
private lazy val (model4, model8) = {
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
Map[String, Any] = {
Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
"checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
}
private def createNewModels():
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val (model4, model8) = {
val training = buildDataFrame(Classification.train)
val paramMap = produceParamMap(tmpPath, 2)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
}
(tmpPath, model4, model8)
}
test("test update/load models") {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateCheckpoint(model4._booster)
val (tmpPath, model4, model8) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model4._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
manager.updateCheckpoint(model8._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
}
test("test cleanUpHigherVersions") {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
val (tmpPath, model4, model8) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model8._booster)
manager.cleanUpHigherVersions(round = 8)
manager.cleanUpHigherVersions(8)
assert(new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(round = 4)
manager.cleanUpHigherVersions(4)
assert(!new File(s"$tmpPath/8.model").exists())
}
test("test checkpoint rounds") {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
import scala.collection.JavaConverters._
val (tmpPath, model4, model8) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(7))(
manager.getCheckpointRounds(0, 7).asScala)
assertResult(Seq(2, 4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
assertResult(Seq(4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
}
@@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
val testDM = new DMatrix(Classification.test.iterator)
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
val paramMap = produceParamMap(tmpPath, 2)
val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
val skipCleanCheckpointMap =
if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
skipCleanCheckpointMap
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
def error(model: Booster): Float = eval.eval(
model.predict(testDM, outPutMargin = true), testDM)
val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)
if (skipCleanCheckpoint) {
// Check only one model is kept after training
@@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) > error(prevModel._booster))
assert(error(tmpModel) >= error(prevModel._booster))
assert(error(prevModel._booster) > error(nextModel._booster))
assert(error(nextModel._booster) < 0.1)
} else {

View File

@@ -127,7 +127,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
" stop the application") {
val spark = ss
import spark.implicits._
ss.sparkContext.setLogLevel("INFO")
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
// vector,
val testDF = Seq(
@@ -155,7 +154,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
"does not stop application") {
val spark = ss
import spark.implicits._
ss.sparkContext.setLogLevel("INFO")
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
// vector,
val testDF = Seq(

View File

@@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.XGBoostError
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
import org.apache.spark.ml.param.ParamMap

View File

@@ -20,14 +20,12 @@ import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.FunSuite
import org.scalatest.{FunSuite, Ignore}
class RabitRobustnessSuite extends FunSuite with PerTest {