[jvm-packages] Use UBJ for checkpoints. (#9954)
This commit is contained in:
parent
38dd91f491
commit
3976455af9
@ -50,13 +50,13 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
manager.updateCheckpoint(model2._booster.booster)
|
||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "1.model")
|
||||
assert(files.head.getPath.getName == "1.ubj")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
|
||||
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "3.model")
|
||||
assert(files.head.getPath.getName == "3.ubj")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
|
||||
}
|
||||
|
||||
@ -66,10 +66,10 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
manager.cleanUpHigherVersions(3)
|
||||
assert(new File(s"$tmpPath/3.model").exists())
|
||||
assert(new File(s"$tmpPath/3.ubj").exists())
|
||||
|
||||
manager.cleanUpHigherVersions(2)
|
||||
assert(!new File(s"$tmpPath/3.model").exists())
|
||||
assert(!new File(s"$tmpPath/3.ubj").exists())
|
||||
}
|
||||
|
||||
test("test checkpoint rounds") {
|
||||
@ -105,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
// Check only one model is kept after training
|
||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
|
||||
assert(files.head.getPath.getName == "4.ubj")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.ubj")
|
||||
// Train next model based on prev model
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) >= error(prevModel._booster))
|
||||
|
||||
@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path;
|
||||
public class ExternalCheckpointManager {
|
||||
|
||||
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
||||
private String modelSuffix = ".model";
|
||||
private String modelSuffix = ".ubj";
|
||||
private Path checkpointPath; // directory for checkpoints
|
||||
private FileSystem fs;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user