[jvm-packages] Saving models into a tmp folder every a few rounds (#2964)

* [jvm-packages] Train Booster from an existing model

* Align Scala API with Java API

* Existing model should not load rabit checkpoint

* Address minor comments

* Implement saving temporary boosters and loading previous booster

* Add more unit tests for loadPrevBooster

* Add params to XGBoostEstimator

* (1) Move repartition out of the temp model saving loop (2) Address CR comments

* Catch a corner case of training next model with fewer rounds

* Address comments

* Refactor newly added methods into TmpBoosterManager

* Add two files which is missing in previous commit

* Rename TmpBooster to checkpoint
This commit is contained in:
Yun Ni
2017-12-29 08:36:41 -08:00
committed by Nan Zhu
parent eedca8c8ec
commit 9004ca03ca
11 changed files with 481 additions and 60 deletions

View File

@@ -0,0 +1,80 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.{SparkConf, SparkContext}
class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
var sc: SparkContext = _
override def beforeAll(): Unit = {
val conf: SparkConf = new SparkConf()
.setMaster("local[*]")
.setAppName("XGBoostSuite")
sc = new SparkContext(conf)
}
private lazy val (model4, model8) = {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
(XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, sc.defaultParallelism),
XGBoost.trainWithRDD(trainingRDD, paramMap, round = 4, sc.defaultParallelism))
}
test("test update/load models") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateModel(model4)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadBooster.booster.getVersion == 4)
manager.updateModel(model8)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadBooster.booster.getVersion == 8)
}
test("test cleanUpHigherVersions") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateModel(model8)
manager.cleanUpHigherVersions(round = 8)
assert(new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(round = 4)
assert(!new File(s"$tmpPath/8.model").exists())
}
test("test saving rounds") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
assertResult(Seq(7))(manager.getSavingRounds(savingFreq = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getSavingRounds(savingFreq = 2, round = 7))
manager.updateModel(model4)
assertResult(Seq(4, 6, 7))(manager.getSavingRounds(2, 7))
}
}

View File

@@ -16,13 +16,14 @@
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
@@ -73,13 +74,14 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)
val partitionedRDD = XGBoost.repartitionForTraining(trainingRDD, 2)
val boosterRDD = XGBoost.buildDistributedBoosters(
trainingRDD,
partitionedRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap,
new java.util.HashMap[String, String](),
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true,
missing = Float.NaN)
round = 5, eval = null, obj = null, useExternalMemory = true,
missing = Float.NaN, prevBooster = null)
val boosterCount = boosterRDD.count()
assert(boosterCount === 2)
}
@@ -335,4 +337,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
}
}
test("training with saving checkpoint boosters") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"saving_frequency" -> 2).toMap
val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
def error(model: XGBoostModel): Float = eval.eval(
model.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix)
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = XGBoost.loadModelFromHadoopFile(s"$tmpPath/8.model")
// Train next model based on prev model
val nextModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 8,
nWorkers = numWorkers)
assert(error(tmpModel) > error(prevModel))
assert(error(prevModel) > error(nextModel))
assert(error(nextModel) < 0.1)
}
}