merge latest, Jan 12 2024
This commit is contained in:
@@ -49,6 +49,22 @@
|
||||
<cudf.classifier>cuda11</cudf.classifier>
|
||||
<scalatest.version>3.2.17</scalatest.version>
|
||||
<scala-collection-compat.version>2.11.0</scala-collection-compat.version>
|
||||
|
||||
<!-- SPARK-36796 for JDK-17 test-->
|
||||
<extraJavaTestArgs>
|
||||
-XX:+IgnoreUnrecognizedVMOptions
|
||||
--add-opens=java.base/java.lang=ALL-UNNAMED
|
||||
--add-opens=java.base/java.lang.invoke=ALL-UNNAMED
|
||||
--add-opens=java.base/java.io=ALL-UNNAMED
|
||||
--add-opens=java.base/java.net=ALL-UNNAMED
|
||||
--add-opens=java.base/java.nio=ALL-UNNAMED
|
||||
--add-opens=java.base/java.util=ALL-UNNAMED
|
||||
--add-opens=java.base/java.util.concurrent=ALL-UNNAMED
|
||||
--add-opens=java.base/sun.nio.ch=ALL-UNNAMED
|
||||
--add-opens=java.base/sun.nio.cs=ALL-UNNAMED
|
||||
--add-opens=java.base/sun.security.action=ALL-UNNAMED
|
||||
--add-opens=java.base/sun.util.calendar=ALL-UNNAMED
|
||||
</extraJavaTestArgs>
|
||||
</properties>
|
||||
<repositories>
|
||||
<repository>
|
||||
@@ -338,6 +354,9 @@
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest-maven-plugin</artifactId>
|
||||
<version>2.2.0</version>
|
||||
<configuration>
|
||||
<argLine>-ea -Xmx4g -Xss4m ${extraJavaTestArgs}</argLine>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>test</id>
|
||||
|
||||
@@ -30,9 +30,6 @@ import org.apache.spark.ml.param.Params
|
||||
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
||||
|
||||
abstract class XGBoostWriter extends MLWriter {
|
||||
|
||||
/** Currently it's using the "deprecated" format as
|
||||
* default, which will be changed into `ubj` in future releases. */
|
||||
def getModelFormat(): String = {
|
||||
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
|
||||
}
|
||||
|
||||
@@ -50,13 +50,13 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
manager.updateCheckpoint(model2._booster.booster)
|
||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "1.model")
|
||||
assert(files.head.getPath.getName == "1.ubj")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
|
||||
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "3.model")
|
||||
assert(files.head.getPath.getName == "3.ubj")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
|
||||
}
|
||||
|
||||
@@ -66,10 +66,10 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
manager.cleanUpHigherVersions(3)
|
||||
assert(new File(s"$tmpPath/3.model").exists())
|
||||
assert(new File(s"$tmpPath/3.ubj").exists())
|
||||
|
||||
manager.cleanUpHigherVersions(2)
|
||||
assert(!new File(s"$tmpPath/3.model").exists())
|
||||
assert(!new File(s"$tmpPath/3.ubj").exists())
|
||||
}
|
||||
|
||||
test("test checkpoint rounds") {
|
||||
@@ -105,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
// Check only one model is kept after training
|
||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
|
||||
assert(files.head.getPath.getName == "4.ubj")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.ubj")
|
||||
// Train next model based on prev model
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) >= error(prevModel._booster))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -432,6 +432,7 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(trainingDF)
|
||||
|
||||
// test json
|
||||
val modelPath = new File(tempDir.toFile, "xgbc").getPath
|
||||
model.write.option("format", "json").save(modelPath)
|
||||
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
|
||||
@@ -439,21 +440,21 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
|
||||
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
|
||||
nativeJsonModelPath))
|
||||
|
||||
// test default "deprecated"
|
||||
// test ubj
|
||||
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
|
||||
model.write.save(modelUbjPath)
|
||||
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
|
||||
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
|
||||
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
|
||||
model.nativeBooster.saveModel(nativeUbjModelPath)
|
||||
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
|
||||
nativeDeprecatedModelPath))
|
||||
nativeUbjModelPath))
|
||||
|
||||
// json file should be indifferent with ubj file
|
||||
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
|
||||
model.write.option("format", "json").save(modelJsonPath)
|
||||
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
|
||||
model.nativeBooster.saveModel(nativeUbjModelPath)
|
||||
val nativeUbjModelPath1 = new File(tempDir.toFile, "nativeModel1.ubj").getPath
|
||||
model.nativeBooster.saveModel(nativeUbjModelPath1)
|
||||
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
|
||||
nativeUbjModelPath))
|
||||
nativeUbjModelPath1))
|
||||
}
|
||||
|
||||
test("native json model file should store feature_name and feature_type") {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -333,21 +333,24 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
|
||||
assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
|
||||
nativeJsonModelPath))
|
||||
|
||||
// test default "deprecated"
|
||||
// test default "ubj"
|
||||
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
|
||||
model.write.save(modelUbjPath)
|
||||
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
|
||||
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
|
||||
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
|
||||
nativeDeprecatedModelPath))
|
||||
|
||||
// json file should be indifferent with ubj file
|
||||
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
|
||||
model.write.option("format", "json").save(modelJsonPath)
|
||||
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
|
||||
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
|
||||
model.nativeBooster.saveModel(nativeUbjModelPath)
|
||||
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostRegressionModel").getPath,
|
||||
nativeUbjModelPath))
|
||||
}
|
||||
|
||||
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
|
||||
nativeUbjModelPath))
|
||||
|
||||
// test the deprecated format
|
||||
val modelDeprecatedPath = new File(tempDir.toFile, "modelDeprecated").getPath
|
||||
model.write.option("format", "deprecated").save(modelDeprecatedPath)
|
||||
|
||||
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel.deprecated").getPath
|
||||
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
|
||||
|
||||
assert(compareTwoFiles(new File(modelDeprecatedPath, "data/XGBoostRegressionModel").getPath,
|
||||
nativeDeprecatedModelPath))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ import org.apache.commons.logging.LogFactory;
|
||||
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
|
||||
*/
|
||||
public class Booster implements Serializable, KryoSerializable {
|
||||
public static final String DEFAULT_FORMAT = "deprecated";
|
||||
public static final String DEFAULT_FORMAT = "ubj";
|
||||
private static final Log logger = LogFactory.getLog(Booster.class);
|
||||
// handle to the booster.
|
||||
private long handle = 0;
|
||||
@@ -788,8 +788,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* Save model into raw byte array. Currently it's using the deprecated format as
|
||||
* default, which will be changed into `ubj` in future releases.
|
||||
* Save model into raw byte array in the UBJSON ("ubj") format.
|
||||
*
|
||||
* @return the saved byte array
|
||||
* @throws XGBoostError native error
|
||||
|
||||
@@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path;
|
||||
public class ExternalCheckpointManager {
|
||||
|
||||
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
||||
private String modelSuffix = ".model";
|
||||
private String modelSuffix = ".ubj";
|
||||
private Path checkpointPath; // directory for checkpoints
|
||||
private FileSystem fs;
|
||||
|
||||
|
||||
@@ -337,8 +337,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save model into a raw byte array. Currently it's using the deprecated format as
|
||||
* default, which will be changed into `ubj` in future releases.
|
||||
* Save model into a raw byte array in the UBJSON ("ubj") format.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def toByteArray: Array[Byte] = {
|
||||
|
||||
Reference in New Issue
Block a user