[breaking] [jvm-packages] Remove rabit check point. (#9599)

- Add `numBoostedRound` to jvm packages
- Remove rabit checkpoint version.
- Change the starting version of training continuation in JVM [breaking].
- Redefine the checkpoint version policy in jvm package. [breaking]
- Rename the Python check point callback parameter. [breaking]
- Unifies the checkpoint policy between Python and JVM.
This commit is contained in:
Jiaming Yuan
2023-09-26 18:06:34 +08:00
committed by GitHub
parent 7901a299b2
commit c75a3bc0a9
15 changed files with 138 additions and 229 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -32,57 +32,53 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
}
private def createNewModels():
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val (model4, model8) = {
val (model2, model4) = {
val training = buildDataFrame(Classification.train)
val paramMap = produceParamMap(tmpPath, 2)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
}
(tmpPath, model4, model8)
(tmpPath, model2, model4)
}
test("test update/load models") {
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model4._booster.booster)
manager.updateCheckpoint(model2._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
assert(files.head.getPath.getName == "1.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
manager.updateCheckpoint(model8._booster)
manager.updateCheckpoint(model4._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
assert(files.head.getPath.getName == "3.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
}
test("test cleanUpHigherVersions") {
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model8._booster)
manager.cleanUpHigherVersions(8)
assert(new File(s"$tmpPath/8.model").exists())
manager.updateCheckpoint(model4._booster)
manager.cleanUpHigherVersions(3)
assert(new File(s"$tmpPath/3.model").exists())
manager.cleanUpHigherVersions(4)
assert(!new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(2)
assert(!new File(s"$tmpPath/3.model").exists())
}
test("test checkpoint rounds") {
import scala.collection.JavaConverters._
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(7))(
manager.getCheckpointRounds(0, 7).asScala)
assertResult(Seq(2, 4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
}
@@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
assert(files.head.getPath.getName == "4.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) >= error(prevModel._booster))