|
|
|
|
@@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark
|
|
|
|
|
|
|
|
|
|
import java.io.File
|
|
|
|
|
|
|
|
|
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
|
|
|
|
import org.scalatest.FunSuite
|
|
|
|
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
|
|
|
|
|
import org.scalatest.{FunSuite, Ignore}
|
|
|
|
|
import org.apache.hadoop.fs.{FileSystem, Path}
|
|
|
|
|
|
|
|
|
|
class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
|
|
|
|
class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
|
|
|
|
|
|
|
|
|
private lazy val (model4, model8) = {
|
|
|
|
|
val training = buildDataFrame(Classification.train)
|
|
|
|
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
|
|
|
|
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
|
|
|
|
|
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
|
|
|
|
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
|
|
|
|
private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
|
|
|
|
|
Map[String, Any] = {
|
|
|
|
|
Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
|
|
|
|
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
|
|
|
|
|
"checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private def createNewModels():
|
|
|
|
|
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
|
|
|
|
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
|
|
|
|
val (model4, model8) = {
|
|
|
|
|
val training = buildDataFrame(Classification.train)
|
|
|
|
|
val paramMap = produceParamMap(tmpPath, 2)
|
|
|
|
|
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
|
|
|
|
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
|
|
|
|
}
|
|
|
|
|
(tmpPath, model4, model8)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
test("test update/load models") {
|
|
|
|
|
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
|
|
|
|
val manager = new CheckpointManager(sc, tmpPath)
|
|
|
|
|
manager.updateCheckpoint(model4._booster)
|
|
|
|
|
val (tmpPath, model4, model8) = createNewModels()
|
|
|
|
|
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
|
|
|
|
|
|
|
|
|
manager.updateCheckpoint(model4._booster.booster)
|
|
|
|
|
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
|
|
|
|
assert(files.length == 1)
|
|
|
|
|
assert(files.head.getPath.getName == "4.model")
|
|
|
|
|
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
|
|
|
|
|
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
|
|
|
|
|
|
|
|
|
|
manager.updateCheckpoint(model8._booster)
|
|
|
|
|
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
|
|
|
|
assert(files.length == 1)
|
|
|
|
|
assert(files.head.getPath.getName == "8.model")
|
|
|
|
|
assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
|
|
|
|
|
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
test("test cleanUpHigherVersions") {
|
|
|
|
|
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
|
|
|
|
val manager = new CheckpointManager(sc, tmpPath)
|
|
|
|
|
val (tmpPath, model4, model8) = createNewModels()
|
|
|
|
|
|
|
|
|
|
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
|
|
|
|
manager.updateCheckpoint(model8._booster)
|
|
|
|
|
manager.cleanUpHigherVersions(round = 8)
|
|
|
|
|
manager.cleanUpHigherVersions(8)
|
|
|
|
|
assert(new File(s"$tmpPath/8.model").exists())
|
|
|
|
|
|
|
|
|
|
manager.cleanUpHigherVersions(round = 4)
|
|
|
|
|
manager.cleanUpHigherVersions(4)
|
|
|
|
|
assert(!new File(s"$tmpPath/8.model").exists())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
test("test checkpoint rounds") {
|
|
|
|
|
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
|
|
|
|
val manager = new CheckpointManager(sc, tmpPath)
|
|
|
|
|
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
|
|
|
|
|
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
|
|
|
|
|
import scala.collection.JavaConverters._
|
|
|
|
|
val (tmpPath, model4, model8) = createNewModels()
|
|
|
|
|
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
|
|
|
|
assertResult(Seq(7))(
|
|
|
|
|
manager.getCheckpointRounds(0, 7).asScala)
|
|
|
|
|
assertResult(Seq(2, 4, 6, 7))(
|
|
|
|
|
manager.getCheckpointRounds(2, 7).asScala)
|
|
|
|
|
manager.updateCheckpoint(model4._booster)
|
|
|
|
|
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
|
|
|
|
|
assertResult(Seq(4, 6, 7))(
|
|
|
|
|
manager.getCheckpointRounds(2, 7).asScala)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
|
|
|
|
|
val testDM = new DMatrix(Classification.test.iterator)
|
|
|
|
|
|
|
|
|
|
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
|
|
|
|
|
|
|
|
|
|
val paramMap = produceParamMap(tmpPath, 2)
|
|
|
|
|
|
|
|
|
|
val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
|
|
|
|
|
val skipCleanCheckpointMap =
|
|
|
|
|
if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
|
|
|
|
|
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
|
|
|
|
|
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
|
|
|
|
|
"checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
|
|
|
|
|
skipCleanCheckpointMap
|
|
|
|
|
|
|
|
|
|
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
|
|
|
|
|
def error(model: Booster): Float = eval.eval(
|
|
|
|
|
model.predict(testDM, outPutMargin = true), testDM)
|
|
|
|
|
val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
|
|
|
|
|
|
|
|
|
|
val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
|
|
|
|
|
|
|
|
|
|
def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)
|
|
|
|
|
|
|
|
|
|
if (skipCleanCheckpoint) {
|
|
|
|
|
// Check only one model is kept after training
|
|
|
|
|
@@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
|
|
|
|
|
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
|
|
|
|
// Train next model based on prev model
|
|
|
|
|
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
|
|
|
|
assert(error(tmpModel) > error(prevModel._booster))
|
|
|
|
|
assert(error(tmpModel) >= error(prevModel._booster))
|
|
|
|
|
assert(error(prevModel._booster) > error(nextModel._booster))
|
|
|
|
|
assert(error(nextModel._booster) < 0.1)
|
|
|
|
|
} else {
|