[jvm-packages] add configuration flag to control whether to cache transformed training set (#4268)

* control whether to cache data

* uncache
This commit is contained in:
Nan Zhu
2019-03-18 10:13:28 +08:00
committed by GitHub
parent 29a1356669
commit 359ed9c5bc
3 changed files with 113 additions and 42 deletions

View File

@@ -286,6 +286,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
assert(error(nextModel._booster) < 0.1)
}
test("training with checkpoint boosters with cached training dataset") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true)
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
def error(model: Booster): Float = eval.eval(
model.predict(testDM, outPutMargin = true), testDM)
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) > error(prevModel._booster))
assert(error(prevModel._booster) > error(nextModel._booster))
assert(error(nextModel._booster) < 0.1)
}
test("repartitionForTrainingGroup with group data") {
// test different splits to cover the corner cases.
for (split <- 1 to 20) {