[jvm-packages] Use UBJ for checkpoints. (#9954)

This commit is contained in:
Jiaming Yuan 2024-01-08 13:26:12 +08:00 committed by GitHub
parent 38dd91f491
commit 3976455af9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 7 deletions

View File

@@ -50,13 +50,13 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
manager.updateCheckpoint(model2._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "1.model")
assert(files.head.getPath.getName == "1.ubj")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
manager.updateCheckpoint(model4._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "3.model")
assert(files.head.getPath.getName == "3.ubj")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
}
@@ -66,10 +66,10 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model4._booster)
manager.cleanUpHigherVersions(3)
assert(new File(s"$tmpPath/3.model").exists())
assert(new File(s"$tmpPath/3.ubj").exists())
manager.cleanUpHigherVersions(2)
assert(!new File(s"$tmpPath/3.model").exists())
assert(!new File(s"$tmpPath/3.ubj").exists())
}
test("test checkpoint rounds") {
@@ -105,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
assert(files.head.getPath.getName == "4.ubj")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.ubj")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) >= error(prevModel._booster))

View File

@@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path;
public class ExternalCheckpointManager {
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
private String modelSuffix = ".model";
private String modelSuffix = ".ubj";
private Path checkpointPath; // directory for checkpoints
private FileSystem fs;