[jvm-packages] Update docs and unify the terminology (#3024)

* [jvm-packages] Move cache files to tmp dir and delete on exit

* [jvm-packages] Update docs and unify terminology

* Address CR Comments
This commit is contained in:
Yun Ni
2018-01-16 08:16:55 -08:00
committed by Sergei Lebedev
parent 84ab74f3a5
commit 3f3f54bcad
5 changed files with 65 additions and 58 deletions

View File

@@ -45,23 +45,23 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
test("test update/load models") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateModel(model4)
manager.updateCheckpoint(model4)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadBooster.booster.getVersion == 4)
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
manager.updateModel(model8)
manager.updateCheckpoint(model8)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadBooster.booster.getVersion == 8)
assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
}
test("test cleanUpHigherVersions") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateModel(model8)
manager.updateCheckpoint(model8)
manager.cleanUpHigherVersions(round = 8)
assert(new File(s"$tmpPath/8.model").exists())
@@ -69,12 +69,12 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
assert(!new File(s"$tmpPath/8.model").exists())
}
test("test saving rounds") {
test("test checkpoint rounds") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
assertResult(Seq(7))(manager.getSavingRounds(savingFreq = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getSavingRounds(savingFreq = 2, round = 7))
manager.updateModel(model4)
assertResult(Seq(4, 6, 7))(manager.getSavingRounds(2, 7))
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
manager.updateCheckpoint(model4)
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
}
}

View File

@@ -338,7 +338,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
}
}
test("training with saving checkpoint boosters") {
test("training with checkpoint boosters") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
@@ -347,7 +347,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"saving_frequency" -> 2).toMap
"checkpoint_interval" -> 2).toMap
val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
def error(model: XGBoostModel): Float = eval.eval(