[jvm-packages] Use UBJ for checkpoints. (#9954)
This commit is contained in:
parent
38dd91f491
commit
3976455af9
@ -50,13 +50,13 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
|||||||
manager.updateCheckpoint(model2._booster.booster)
|
manager.updateCheckpoint(model2._booster.booster)
|
||||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
assert(files.length == 1)
|
assert(files.length == 1)
|
||||||
assert(files.head.getPath.getName == "1.model")
|
assert(files.head.getPath.getName == "1.ubj")
|
||||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
|
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
|
||||||
|
|
||||||
manager.updateCheckpoint(model4._booster)
|
manager.updateCheckpoint(model4._booster)
|
||||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
assert(files.length == 1)
|
assert(files.length == 1)
|
||||||
assert(files.head.getPath.getName == "3.model")
|
assert(files.head.getPath.getName == "3.ubj")
|
||||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
|
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,10 +66,10 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
|||||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||||
manager.updateCheckpoint(model4._booster)
|
manager.updateCheckpoint(model4._booster)
|
||||||
manager.cleanUpHigherVersions(3)
|
manager.cleanUpHigherVersions(3)
|
||||||
assert(new File(s"$tmpPath/3.model").exists())
|
assert(new File(s"$tmpPath/3.ubj").exists())
|
||||||
|
|
||||||
manager.cleanUpHigherVersions(2)
|
manager.cleanUpHigherVersions(2)
|
||||||
assert(!new File(s"$tmpPath/3.model").exists())
|
assert(!new File(s"$tmpPath/3.ubj").exists())
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test checkpoint rounds") {
|
test("test checkpoint rounds") {
|
||||||
@ -105,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
|||||||
// Check only one model is kept after training
|
// Check only one model is kept after training
|
||||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
assert(files.length == 1)
|
assert(files.length == 1)
|
||||||
assert(files.head.getPath.getName == "4.model")
|
assert(files.head.getPath.getName == "4.ubj")
|
||||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
|
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.ubj")
|
||||||
// Train next model based on prev model
|
// Train next model based on prev model
|
||||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||||
assert(error(tmpModel) >= error(prevModel._booster))
|
assert(error(tmpModel) >= error(prevModel._booster))
|
||||||
|
|||||||
@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path;
|
|||||||
public class ExternalCheckpointManager {
|
public class ExternalCheckpointManager {
|
||||||
|
|
||||||
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
||||||
private String modelSuffix = ".model";
|
private String modelSuffix = ".ubj";
|
||||||
private Path checkpointPath; // directory for checkpoints
|
private Path checkpointPath; // directory for checkpoints
|
||||||
private FileSystem fs;
|
private FileSystem fs;
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user