diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala index e6835158d..729bd9c77 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala @@ -50,13 +50,13 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite manager.updateCheckpoint(model2._booster.booster) var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) - assert(files.head.getPath.getName == "1.model") + assert(files.head.getPath.getName == "1.ubj") assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2) manager.updateCheckpoint(model4._booster) files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) - assert(files.head.getPath.getName == "3.model") + assert(files.head.getPath.getName == "3.ubj") assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4) } @@ -66,10 +66,10 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration)) manager.updateCheckpoint(model4._booster) manager.cleanUpHigherVersions(3) - assert(new File(s"$tmpPath/3.model").exists()) + assert(new File(s"$tmpPath/3.ubj").exists()) manager.cleanUpHigherVersions(2) - assert(!new File(s"$tmpPath/3.model").exists()) + assert(!new File(s"$tmpPath/3.ubj").exists()) } test("test checkpoint rounds") { @@ -105,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite // Check only one model is kept after training val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) - assert(files.head.getPath.getName == "4.model") - val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model") + assert(files.head.getPath.getName == "4.ubj") + val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.ubj") // Train next model based on prev model val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training) assert(error(tmpModel) >= error(prevModel._booster)) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java index 3d794756d..d5b8e8b9c 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path; public class ExternalCheckpointManager { private Log logger = LogFactory.getLog("ExternalCheckpointManager"); - private String modelSuffix = ".model"; + private String modelSuffix = ".ubj"; private Path checkpointPath; // directory for checkpoints private FileSystem fs;