Save model in ubj as the default. (#9947)

This commit is contained in:
Jiaming Yuan
2024-01-05 17:53:36 +08:00
committed by GitHub
parent c03a4d5088
commit 38dd91f491
23 changed files with 598 additions and 550 deletions

View File

@@ -30,9 +30,6 @@ import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
abstract class XGBoostWriter extends MLWriter {
/** Currently it's using the "deprecated" format as
* default, which will be changed into `ubj` in future releases. */
def getModelFormat(): String = {
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -432,6 +432,7 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
// test json
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
@@ -439,21 +440,21 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
nativeJsonModelPath))
// test default "deprecated"
// test ubj
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
nativeDeprecatedModelPath))
nativeUbjModelPath))
// json file should be indifferent with ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
val nativeUbjModelPath1 = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath1)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath))
nativeUbjModelPath1))
}
test("native json model file should store feature_name and feature_type") {

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -333,21 +333,24 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
nativeJsonModelPath))
// test default "deprecated"
// test default "ubj"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
nativeDeprecatedModelPath))
// json file should be indifferent with ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostRegressionModel").getPath,
nativeUbjModelPath))
}
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
nativeUbjModelPath))
// test the deprecated format
val modelDeprecatedPath = new File(tempDir.toFile, "modelDeprecated").getPath
model.write.option("format", "deprecated").save(modelDeprecatedPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel.deprecated").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelDeprecatedPath, "data/XGBoostRegressionModel").getPath,
nativeDeprecatedModelPath))
}
}