[jvm-packages] Add the new device parameter. (#9385)

This commit is contained in:
Jiaming Yuan 2023-07-17 18:40:39 +08:00 committed by GitHub
parent 2caceb157d
commit f4fb2be101
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 112 additions and 47 deletions

View File

@ -121,7 +121,7 @@ To train a XGBoost model for classification, we need to claim a XGBoostClassifie
"objective" -> "multi:softprob", "objective" -> "multi:softprob",
"num_class" -> 3, "num_class" -> 3,
"num_round" -> 100, "num_round" -> 100,
"tree_method" -> "gpu_hist", "device" -> "cuda",
"num_workers" -> 1) "num_workers" -> 1)
val featuresNames = schema.fieldNames.filter(name => name != labelName) val featuresNames = schema.fieldNames.filter(name => name != labelName)
@ -130,15 +130,14 @@ To train a XGBoost model for classification, we need to claim a XGBoostClassifie
.setFeaturesCol(featuresNames) .setFeaturesCol(featuresNames)
.setLabelCol(labelName) .setLabelCol(labelName)
The available parameters for training a XGBoost model can be found in :doc:`here </parameter>`. The ``device`` parameter is for informing XGBoost that CUDA devices should be used instead of CPU. Unlike the single-node mode, GPUs are managed by spark instead of by XGBoost. Therefore, explicitly specified device ordinal like ``cuda:1`` is not support.
Similar to the XGBoost4J-Spark package, in addition to the default set of parameters,
XGBoost4J-Spark-GPU also supports the camel-case variant of these parameters to be The available parameters for training a XGBoost model can be found in :doc:`here </parameter>`. Similar to the XGBoost4J-Spark package, in addition to the default set of parameters, XGBoost4J-Spark-GPU also supports the camel-case variant of these parameters to be consistent with Spark's MLlib naming convention.
consistent with Spark's MLlib naming convention.
Specifically, each parameter in :doc:`this page </parameter>` has its equivalent form in Specifically, each parameter in :doc:`this page </parameter>` has its equivalent form in
XGBoost4J-Spark-GPU with camel case. For example, to set ``max_depth`` for each tree, you can pass XGBoost4J-Spark-GPU with camel case. For example, to set ``max_depth`` for each tree, you
parameter just like what we did in the above code snippet (as ``max_depth`` wrapped in a Map), or can pass parameter just like what we did in the above code snippet (as ``max_depth``
you can do it through setters in XGBoostClassifer: wrapped in a Map), or you can do it through setters in XGBoostClassifer:
.. code-block:: scala .. code-block:: scala

View File

@ -40,20 +40,20 @@ object SparkMLlibPipeline {
val nativeModelPath = args(1) val nativeModelPath = args(1)
val pipelineModelPath = args(2) val pipelineModelPath = args(2)
val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") { val (device, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
("gpu_hist", 1) ("cuda", 1)
} else ("auto", 2) } else ("cpu", 2)
val spark = SparkSession val spark = SparkSession
.builder() .builder()
.appName("XGBoost4J-Spark Pipeline Example") .appName("XGBoost4J-Spark Pipeline Example")
.getOrCreate() .getOrCreate()
run(spark, inputPath, nativeModelPath, pipelineModelPath, treeMethod, numWorkers) run(spark, inputPath, nativeModelPath, pipelineModelPath, device, numWorkers)
.show(false) .show(false)
} }
private[spark] def run(spark: SparkSession, inputPath: String, nativeModelPath: String, private[spark] def run(spark: SparkSession, inputPath: String, nativeModelPath: String,
pipelineModelPath: String, treeMethod: String, pipelineModelPath: String, device: String,
numWorkers: Int): DataFrame = { numWorkers: Int): DataFrame = {
// Load dataset // Load dataset
@ -82,13 +82,14 @@ object SparkMLlibPipeline {
.setOutputCol("classIndex") .setOutputCol("classIndex")
.fit(training) .fit(training)
val booster = new XGBoostClassifier( val booster = new XGBoostClassifier(
Map("eta" -> 0.1f, Map(
"eta" -> 0.1f,
"max_depth" -> 2, "max_depth" -> 2,
"objective" -> "multi:softprob", "objective" -> "multi:softprob",
"num_class" -> 3, "num_class" -> 3,
"num_round" -> 100, "num_round" -> 100,
"num_workers" -> numWorkers, "num_workers" -> numWorkers,
"tree_method" -> treeMethod "device" -> device
) )
) )
booster.setFeaturesCol("features") booster.setFeaturesCol("features")

View File

@ -31,18 +31,18 @@ object SparkTraining {
sys.exit(1) sys.exit(1)
} }
val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") { val (device, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
("gpu_hist", 1) ("cuda", 1)
} else ("auto", 2) } else ("cpu", 2)
val spark = SparkSession.builder().getOrCreate() val spark = SparkSession.builder().getOrCreate()
val inputPath = args(0) val inputPath = args(0)
val results: DataFrame = run(spark, inputPath, treeMethod, numWorkers) val results: DataFrame = run(spark, inputPath, device, numWorkers)
results.show() results.show()
} }
private[spark] def run(spark: SparkSession, inputPath: String, private[spark] def run(spark: SparkSession, inputPath: String,
treeMethod: String, numWorkers: Int): DataFrame = { device: String, numWorkers: Int): DataFrame = {
val schema = new StructType(Array( val schema = new StructType(Array(
StructField("sepal length", DoubleType, true), StructField("sepal length", DoubleType, true),
StructField("sepal width", DoubleType, true), StructField("sepal width", DoubleType, true),
@ -80,7 +80,7 @@ private[spark] def run(spark: SparkSession, inputPath: String,
"num_class" -> 3, "num_class" -> 3,
"num_round" -> 100, "num_round" -> 100,
"num_workers" -> numWorkers, "num_workers" -> numWorkers,
"tree_method" -> treeMethod, "device" -> device,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2)) "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgbClassifier = new XGBoostClassifier(xgbParam). val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features"). setFeaturesCol("features").

View File

@ -104,7 +104,7 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
test("Smoke test for SparkMLlibPipeline example") { test("Smoke test for SparkMLlibPipeline example") {
SparkMLlibPipeline.run(spark, pathToTestDataset.toString, "target/native-model", SparkMLlibPipeline.run(spark, pathToTestDataset.toString, "target/native-model",
"target/pipeline-model", "auto", 2) "target/pipeline-model", "cpu", 2)
} }
test("Smoke test for SparkTraining example") { test("Smoke test for SparkTraining example") {
@ -118,6 +118,6 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
.config("spark.task.cpus", 1) .config("spark.task.cpus", 1)
.getOrCreate() .getOrCreate()
SparkTraining.run(spark, pathToTestDataset.toString, "auto", 2) SparkTraining.run(spark, pathToTestDataset.toString, "cpu", 2)
} }
} }

View File

@ -77,7 +77,8 @@ public class BoosterTest {
put("objective", "binary:logistic"); put("objective", "binary:logistic");
put("num_round", round); put("num_round", round);
put("num_workers", 1); put("num_workers", 1);
put("tree_method", "gpu_hist"); put("tree_method", "hist");
put("device", "cuda");
put("max_bin", maxBin); put("max_bin", maxBin);
} }
}; };

View File

@ -137,8 +137,12 @@ object GpuPreXGBoost extends PreXGBoostProvider {
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) = val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
estimator match { estimator match {
case est: XGBoostEstimatorCommon => case est: XGBoostEstimatorCommon =>
require(est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"), require(
s"GPU train requires tree_method set to gpu_hist") est.isDefined(est.device) &&
(est.getDevice.equals("cuda") || est.getDevice.equals("gpu")) ||
est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
s"GPU train requires `device` set to `cuda` or `gpu`."
)
val groupName = estimator match { val groupName = estimator match {
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) { case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
regressor.getGroupCol } else "" regressor.getGroupCol } else ""

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2021-2022 by Contributors Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -50,9 +50,12 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
withGpuSparkSession() { spark => withGpuSparkSession() { spark =>
import spark.implicits._ import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*) val trainingDf = trainingData.toDF(allColumnNames: _*)
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob", val xgbParam = Map(
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist", "eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"features_cols" -> featureNames, "label_col" -> labelName) "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
"tree_method" -> "hist", "device" -> "cuda",
"features_cols" -> featureNames, "label_col" -> labelName
)
new XGBoostClassifier(xgbParam) new XGBoostClassifier(xgbParam)
.fit(trainingDf) .fit(trainingDf)
} }
@ -65,8 +68,11 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
trainingDf = trainingDf.select(labelName, "f2", weightName, "f3", baseMarginName, "f1") trainingDf = trainingDf.select(labelName, "f2", weightName, "f3", baseMarginName, "f1")
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob", val xgbParam = Map(
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist") "eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
"tree_method" -> "hist", "device" -> "cuda"
)
new XGBoostClassifier(xgbParam) new XGBoostClassifier(xgbParam)
.setFeaturesCol(featureNames) .setFeaturesCol(featureNames)
.setLabelCol(labelName) .setLabelCol(labelName)
@ -127,7 +133,7 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
} }
} }
test("Throw exception when tree method is not set to gpu_hist") { test("Throw exception when device is not set to cuda") {
withGpuSparkSession() { spark => withGpuSparkSession() { spark =>
import spark.implicits._ import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*) val trainingDf = trainingData.toDF(allColumnNames: _*)
@ -139,12 +145,11 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
.setLabelCol(labelName) .setLabelCol(labelName)
.fit(trainingDf) .fit(trainingDf)
} }
assert(thrown.getMessage.contains("GPU train requires tree_method set to gpu_hist")) assert(thrown.getMessage.contains("GPU train requires `device` set to `cuda`"))
} }
} }
test("Train with eval") { test("Train with eval") {
withGpuSparkSession() { spark => withGpuSparkSession() { spark =>
import spark.implicits._ import spark.implicits._
val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*) val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*)
@ -184,4 +189,24 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
} }
} }
test("device ordinal should not be specified") {
withGpuSparkSession() { spark =>
import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*)
val params = Map(
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 5,
"num_workers" -> 1
)
val thrown = intercept[IllegalArgumentException] {
new XGBoostClassifier(params)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.setDevice("cuda:1")
.fit(trainingDf)
}
assert(thrown.getMessage.contains("`cuda` or `gpu`"))
}
}
} }

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2021-2022 by Contributors Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -40,7 +40,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
test("The transform result should be same for several runs on same model") { test("The transform result should be same for several runs on same model") {
withGpuSparkSession(enableCsvConf()) { spark => withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror", val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist", "num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
"features_cols" -> featureNames, "label_col" -> labelName) "features_cols" -> featureNames, "label_col" -> labelName)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema) val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1) .csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
@ -54,10 +54,30 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
} }
} }
test("Tree method gpu_hist still works") {
withGpuSparkSession(enableCsvConf()) { spark =>
val params = Map(
"tree_method" -> "gpu_hist",
"features_cols" -> featureNames,
"label_col" -> labelName,
"num_round" -> 10,
"num_workers" -> 1
)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
// Get a model
val model = new XGBoostRegressor(params).fit(originalDf)
val left = model.transform(testDf).collect()
val right = model.transform(testDf).collect()
// The left should be same with right
assert(compareResults(true, 0.000001, left, right))
}
}
test("use weight") { test("use weight") {
withGpuSparkSession(enableCsvConf()) { spark => withGpuSparkSession(enableCsvConf()) { spark =>
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror", val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist", "num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
"features_cols" -> featureNames, "label_col" -> labelName) "features_cols" -> featureNames, "label_col" -> labelName)
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema) val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1) .csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
@ -88,7 +108,8 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
val classifier = new XGBoostRegressor(xgbParam) val classifier = new XGBoostRegressor(xgbParam)
.setFeaturesCol(featureNames) .setFeaturesCol(featureNames)
.setLabelCol(labelName) .setLabelCol(labelName)
.setTreeMethod("gpu_hist") .setTreeMethod("hist")
.setDevice("cuda")
(classifier.fit(rawInput), testDf) (classifier.fit(rawInput), testDf)
} }
@ -175,7 +196,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
val classifier = new XGBoostRegressor(xgbParam) val classifier = new XGBoostRegressor(xgbParam)
.setFeaturesCol(featureNames) .setFeaturesCol(featureNames)
.setLabelCol(labelName) .setLabelCol(labelName)
.setTreeMethod("gpu_hist") .setDevice("cuda")
classifier.fit(rawInput) classifier.fit(rawInput)
} }
@ -234,5 +255,4 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
assert(testDf.count() === ret.length) assert(testDf.count() === ret.length)
} }
} }
} }

View File

@ -73,7 +73,7 @@ private[scala] case class XGBoostExecutionParams(
xgbInputParams: XGBoostExecutionInputParams, xgbInputParams: XGBoostExecutionInputParams,
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams, earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean, cacheTrainingSet: Boolean,
treeMethod: Option[String], device: Option[String],
isLocal: Boolean, isLocal: Boolean,
featureNames: Option[Array[String]], featureNames: Option[Array[String]],
featureTypes: Option[Array[String]]) { featureTypes: Option[Array[String]]) {
@ -180,6 +180,10 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
" as 'hist', 'approx', 'gpu_hist', and 'auto'") " as 'hist', 'approx', 'gpu_hist', and 'auto'")
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String]) treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
} }
val device: Option[String] = overridedParams.get("device") match {
case None => None
case Some(dev: String) => if (treeMethod == "gpu_hist") Some("cuda") else Some(dev)
}
if (overridedParams.contains("train_test_ratio")) { if (overridedParams.contains("train_test_ratio")) {
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " + " pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
@ -228,7 +232,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
inputParams, inputParams,
xgbExecEarlyStoppingParams, xgbExecEarlyStoppingParams,
cacheTrainingSet, cacheTrainingSet,
treeMethod, device,
isLocal, isLocal,
featureNames, featureNames,
featureTypes featureTypes
@ -318,7 +322,7 @@ object XGBoost extends Serializable {
val externalCheckpointParams = xgbExecutionParam.checkpointParam val externalCheckpointParams = xgbExecutionParam.checkpointParam
var params = xgbExecutionParam.toMap var params = xgbExecutionParam.toMap
if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) { if (xgbExecutionParam.device.exists(m => (m == "cuda" || m == "gpu"))) {
val gpuId = if (xgbExecutionParam.isLocal) { val gpuId = if (xgbExecutionParam.isLocal) {
// For local mode, force gpu id to primary device // For local mode, force gpu id to primary device
0 0
@ -328,6 +332,7 @@ object XGBoost extends Serializable {
logger.info("Leveraging gpu device " + gpuId + " to train") logger.info("Leveraging gpu device " + gpuId + " to train")
params = params + ("device" -> s"cuda:$gpuId") params = params + ("device" -> s"cuda:$gpuId")
} }
val booster = if (makeCheckpoint) { val booster = if (makeCheckpoint) {
SXGBoost.trainAndSaveCheckpoint( SXGBoost.trainAndSaveCheckpoint(
watches.toMap("train"), params, numRounds, watches.toMap("train"), params, numRounds,

View File

@ -93,6 +93,8 @@ class XGBoostClassifier (
def setTreeMethod(value: String): this.type = set(treeMethod, value) def setTreeMethod(value: String): this.type = set(treeMethod, value)
def setDevice(value: String): this.type = set(device, value)
def setGrowPolicy(value: String): this.type = set(growPolicy, value) def setGrowPolicy(value: String): this.type = set(growPolicy, value)
def setMaxBins(value: Int): this.type = set(maxBins, value) def setMaxBins(value: Int): this.type = set(maxBins, value)

View File

@ -95,6 +95,8 @@ class XGBoostRegressor (
def setTreeMethod(value: String): this.type = set(treeMethod, value) def setTreeMethod(value: String): this.type = set(treeMethod, value)
def setDevice(value: String): this.type = set(device, value)
def setGrowPolicy(value: String): this.type = set(growPolicy, value) def setGrowPolicy(value: String): this.type = set(growPolicy, value)
def setMaxBins(value: Int): this.type = set(maxBins, value) def setMaxBins(value: Int): this.type = set(maxBins, value)

View File

@ -154,6 +154,14 @@ private[spark] trait BoosterParams extends Params {
(value: String) => BoosterParams.supportedTreeMethods.contains(value)) (value: String) => BoosterParams.supportedTreeMethods.contains(value))
final def getTreeMethod: String = $(treeMethod) final def getTreeMethod: String = $(treeMethod)
/**
* The device for running XGBoost algorithms, options: cpu, cuda
*/
final val device = new Param[String](
this, "device", "The device for running XGBoost algorithms, options: cpu, cuda"
)
final def getDevice: String = $(device)
/** /**
* growth policy for fast histogram algorithm * growth policy for fast histogram algorithm

View File

@ -284,7 +284,7 @@ private[spark] trait ParamMapFuncs extends Params {
(paramName == "updater" && paramValue != "grow_histmaker,prune" && (paramName == "updater" && paramValue != "grow_histmaker,prune" &&
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) { paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," + throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker,prune or" + s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker or" +
s" grow_quantile_histmaker or grow_gpu_hist as the updater type") s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
} }
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName) val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)

View File

@ -469,7 +469,6 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
.setFeatureTypes(featureTypes) .setFeatureTypes(featureTypes)
val model = xgb.fit(trainingDF) val model = xgb.fit(trainingDF)
val modelStr = new String(model._booster.toByteArray("json")) val modelStr = new String(model._booster.toByteArray("json"))
System.out.println(modelStr)
val jsonModel = parseJson(modelStr) val jsonModel = parseJson(modelStr)
implicit val formats: Formats = DefaultFormats implicit val formats: Formats = DefaultFormats
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]] val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]

View File

@ -143,7 +143,6 @@ public class BoosterImplTest {
booster.saveModel(temp.getAbsolutePath()); booster.saveModel(temp.getAbsolutePath());
String modelString = new String(booster.toByteArray("json")); String modelString = new String(booster.toByteArray("json"));
System.out.println(modelString);
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath()); Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj"))); assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));