From f4fb2be101034f4a43ce9e79cc0e1375906d23e7 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 17 Jul 2023 18:40:39 +0800 Subject: [PATCH] [jvm-packages] Add the new `device` parameter. (#9385) --- doc/jvm/xgboost4j_spark_gpu_tutorial.rst | 15 +++---- .../example/spark/SparkMLlibPipeline.scala | 15 ++++--- .../scala/example/spark/SparkTraining.scala | 12 +++--- .../example/spark/SparkExamplesTest.scala | 4 +- .../dmlc/xgboost4j/gpu/java/BoosterTest.java | 3 +- .../scala/rapids/spark/GpuPreXGBoost.scala | 8 +++- .../rapids/spark/GpuXGBoostGeneralSuite.scala | 43 +++++++++++++++---- .../spark/GpuXGBoostRegressorSuite.scala | 32 +++++++++++--- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 11 +++-- .../scala/spark/XGBoostClassifier.scala | 2 + .../scala/spark/XGBoostRegressor.scala | 2 + .../scala/spark/params/BoosterParams.scala | 8 ++++ .../scala/spark/params/GeneralParams.scala | 2 +- .../scala/spark/XGBoostClassifierSuite.scala | 1 - .../dmlc/xgboost4j/java/BoosterImplTest.java | 1 - 15 files changed, 112 insertions(+), 47 deletions(-) diff --git a/doc/jvm/xgboost4j_spark_gpu_tutorial.rst b/doc/jvm/xgboost4j_spark_gpu_tutorial.rst index f3b97d9c3..7b80286ef 100644 --- a/doc/jvm/xgboost4j_spark_gpu_tutorial.rst +++ b/doc/jvm/xgboost4j_spark_gpu_tutorial.rst @@ -121,7 +121,7 @@ To train a XGBoost model for classification, we need to claim a XGBoostClassifie "objective" -> "multi:softprob", "num_class" -> 3, "num_round" -> 100, - "tree_method" -> "gpu_hist", + "device" -> "cuda", "num_workers" -> 1) val featuresNames = schema.fieldNames.filter(name => name != labelName) @@ -130,15 +130,14 @@ To train a XGBoost model for classification, we need to claim a XGBoostClassifie .setFeaturesCol(featuresNames) .setLabelCol(labelName) -The available parameters for training a XGBoost model can be found in :doc:`here `. 
-Similar to the XGBoost4J-Spark package, in addition to the default set of parameters, -XGBoost4J-Spark-GPU also supports the camel-case variant of these parameters to be -consistent with Spark's MLlib naming convention. +The ``device`` parameter is for informing XGBoost that CUDA devices should be used instead of CPU. Unlike the single-node mode, GPUs are managed by Spark instead of by XGBoost. Therefore, an explicitly specified device ordinal like ``cuda:1`` is not supported. + +The available parameters for training a XGBoost model can be found in :doc:`here `. Similar to the XGBoost4J-Spark package, in addition to the default set of parameters, XGBoost4J-Spark-GPU also supports the camel-case variant of these parameters to be consistent with Spark's MLlib naming convention. Specifically, each parameter in :doc:`this page ` has its equivalent form in -XGBoost4J-Spark-GPU with camel case. For example, to set ``max_depth`` for each tree, you can pass -parameter just like what we did in the above code snippet (as ``max_depth`` wrapped in a Map), or -you can do it through setters in XGBoostClassifer: +XGBoost4J-Spark-GPU with camel case. For example, to set ``max_depth`` for each tree, you +can pass the parameter just like what we did in the above code snippet (as ``max_depth`` +wrapped in a Map), or you can do it through setters in XGBoostClassifier: ..
code-block:: scala diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala index b8da31c09..ae59af571 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala @@ -40,20 +40,20 @@ object SparkMLlibPipeline { val nativeModelPath = args(1) val pipelineModelPath = args(2) - val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") { - ("gpu_hist", 1) - } else ("auto", 2) + val (device, numWorkers) = if (args.length == 4 && args(3) == "gpu") { + ("cuda", 1) + } else ("cpu", 2) val spark = SparkSession .builder() .appName("XGBoost4J-Spark Pipeline Example") .getOrCreate() - run(spark, inputPath, nativeModelPath, pipelineModelPath, treeMethod, numWorkers) + run(spark, inputPath, nativeModelPath, pipelineModelPath, device, numWorkers) .show(false) } private[spark] def run(spark: SparkSession, inputPath: String, nativeModelPath: String, - pipelineModelPath: String, treeMethod: String, + pipelineModelPath: String, device: String, numWorkers: Int): DataFrame = { // Load dataset @@ -82,13 +82,14 @@ object SparkMLlibPipeline { .setOutputCol("classIndex") .fit(training) val booster = new XGBoostClassifier( - Map("eta" -> 0.1f, + Map( + "eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob", "num_class" -> 3, "num_round" -> 100, "num_workers" -> numWorkers, - "tree_method" -> treeMethod + "device" -> device ) ) booster.setFeaturesCol("features") diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala index a7886f524..67a9f7e23 100644 --- 
a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala @@ -31,18 +31,18 @@ object SparkTraining { sys.exit(1) } - val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") { - ("gpu_hist", 1) - } else ("auto", 2) + val (device, numWorkers) = if (args.length == 2 && args(1) == "gpu") { + ("cuda", 1) + } else ("cpu", 2) val spark = SparkSession.builder().getOrCreate() val inputPath = args(0) - val results: DataFrame = run(spark, inputPath, treeMethod, numWorkers) + val results: DataFrame = run(spark, inputPath, device, numWorkers) results.show() } private[spark] def run(spark: SparkSession, inputPath: String, - treeMethod: String, numWorkers: Int): DataFrame = { + device: String, numWorkers: Int): DataFrame = { val schema = new StructType(Array( StructField("sepal length", DoubleType, true), StructField("sepal width", DoubleType, true), @@ -80,7 +80,7 @@ private[spark] def run(spark: SparkSession, inputPath: String, "num_class" -> 3, "num_round" -> 100, "num_workers" -> numWorkers, - "tree_method" -> treeMethod, + "device" -> device, "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2)) val xgbClassifier = new XGBoostClassifier(xgbParam). setFeaturesCol("features"). 
diff --git a/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkExamplesTest.scala b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkExamplesTest.scala index f6cb700df..2e87bf066 100644 --- a/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkExamplesTest.scala +++ b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkExamplesTest.scala @@ -104,7 +104,7 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll { test("Smoke test for SparkMLlibPipeline example") { SparkMLlibPipeline.run(spark, pathToTestDataset.toString, "target/native-model", - "target/pipeline-model", "auto", 2) + "target/pipeline-model", "cpu", 2) } test("Smoke test for SparkTraining example") { @@ -118,6 +118,6 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll { .config("spark.task.cpus", 1) .getOrCreate() - SparkTraining.run(spark, pathToTestDataset.toString, "auto", 2) + SparkTraining.run(spark, pathToTestDataset.toString, "cpu", 2) } } diff --git a/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/BoosterTest.java b/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/BoosterTest.java index 24a1491e1..ce830ef99 100644 --- a/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/BoosterTest.java +++ b/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/BoosterTest.java @@ -77,7 +77,8 @@ public class BoosterTest { put("objective", "binary:logistic"); put("num_round", round); put("num_workers", 1); - put("tree_method", "gpu_hist"); + put("tree_method", "hist"); + put("device", "cuda"); put("max_bin", maxBin); } }; diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala index 
9ff42e370..d34802805 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala @@ -137,8 +137,12 @@ object GpuPreXGBoost extends PreXGBoostProvider { val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) = estimator match { case est: XGBoostEstimatorCommon => - require(est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"), - s"GPU train requires tree_method set to gpu_hist") + require( + est.isDefined(est.device) && + (est.getDevice.equals("cuda") || est.getDevice.equals("gpu")) || + est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"), + s"GPU train requires `device` set to `cuda` or `gpu`." + ) val groupName = estimator match { case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) { regressor.getGroupCol } else "" diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala index 3d643761a..c731afb1d 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021-2022 by Contributors + Copyright (c) 2021-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -50,9 +50,12 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite { withGpuSparkSession() { spark => import spark.implicits._ val trainingDf = trainingData.toDF(allColumnNames: _*) - val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob", - "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist", - "features_cols" -> featureNames, "label_col" -> labelName) + val xgbParam = Map( + "eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob", + "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, + "tree_method" -> "hist", "device" -> "cuda", + "features_cols" -> featureNames, "label_col" -> labelName + ) new XGBoostClassifier(xgbParam) .fit(trainingDf) } @@ -65,8 +68,11 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite { trainingDf = trainingDf.select(labelName, "f2", weightName, "f3", baseMarginName, "f1") - val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob", - "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist") + val xgbParam = Map( + "eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob", + "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, + "tree_method" -> "hist", "device" -> "cuda" + ) new XGBoostClassifier(xgbParam) .setFeaturesCol(featureNames) .setLabelCol(labelName) @@ -127,7 +133,7 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite { } } - test("Throw exception when tree method is not set to gpu_hist") { + test("Throw exception when device is not set to cuda") { withGpuSparkSession() { spark => import spark.implicits._ val trainingDf = trainingData.toDF(allColumnNames: _*) @@ -139,12 +145,11 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite { .setLabelCol(labelName) .fit(trainingDf) } - assert(thrown.getMessage.contains("GPU train requires tree_method set to gpu_hist")) + assert(thrown.getMessage.contains("GPU train requires `device` set to `cuda`")) } } test("Train with eval") { - 
withGpuSparkSession() { spark => import spark.implicits._ val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*) @@ -184,4 +189,24 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite { } } + test("device ordinal should not be specified") { + withGpuSparkSession() { spark => + import spark.implicits._ + val trainingDf = trainingData.toDF(allColumnNames: _*) + val params = Map( + "objective" -> "multi:softprob", + "num_class" -> 3, + "num_round" -> 5, + "num_workers" -> 1 + ) + val thrown = intercept[IllegalArgumentException] { + new XGBoostClassifier(params) + .setFeaturesCol(featureNames) + .setLabelCol(labelName) + .setDevice("cuda:1") + .fit(trainingDf) + } + assert(thrown.getMessage.contains("`cuda` or `gpu`")) + } + } } diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala index b8dca5d70..6c58ae9fc 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021-2022 by Contributors + Copyright (c) 2021-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -40,7 +40,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { test("The transform result should be same for several runs on same model") { withGpuSparkSession(enableCsvConf()) { spark => val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror", - "num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist", + "num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda", "features_cols" -> featureNames, "label_col" -> labelName) val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema) .csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1) @@ -54,10 +54,30 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { } } + test("Tree method gpu_hist still works") { + withGpuSparkSession(enableCsvConf()) { spark => + val params = Map( + "tree_method" -> "gpu_hist", + "features_cols" -> featureNames, + "label_col" -> labelName, + "num_round" -> 10, + "num_workers" -> 1 + ) + val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema) + .csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1) + // Get a model + val model = new XGBoostRegressor(params).fit(originalDf) + val left = model.transform(testDf).collect() + val right = model.transform(testDf).collect() + // The left should be same with right + assert(compareResults(true, 0.000001, left, right)) + } + } + test("use weight") { withGpuSparkSession(enableCsvConf()) { spark => val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror", - "num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist", + "num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda", "features_cols" -> featureNames, "label_col" -> labelName) val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema) .csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1) @@ -88,7 +108,8 @@ class 
GpuXGBoostRegressorSuite extends GpuTestSuite { val classifier = new XGBoostRegressor(xgbParam) .setFeaturesCol(featureNames) .setLabelCol(labelName) - .setTreeMethod("gpu_hist") + .setTreeMethod("hist") + .setDevice("cuda") (classifier.fit(rawInput), testDf) } @@ -175,7 +196,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { val classifier = new XGBoostRegressor(xgbParam) .setFeaturesCol(featureNames) .setLabelCol(labelName) - .setTreeMethod("gpu_hist") + .setDevice("cuda") classifier.fit(rawInput) } @@ -234,5 +255,4 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { assert(testDf.count() === ret.length) } } - } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 48b31a99f..5fc16ec09 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -73,7 +73,7 @@ private[scala] case class XGBoostExecutionParams( xgbInputParams: XGBoostExecutionInputParams, earlyStoppingParams: XGBoostExecutionEarlyStoppingParams, cacheTrainingSet: Boolean, - treeMethod: Option[String], + device: Option[String], isLocal: Boolean, featureNames: Option[Array[String]], featureTypes: Option[Array[String]]) { @@ -180,6 +180,10 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s " as 'hist', 'approx', 'gpu_hist', and 'auto'") treeMethod = Some(overridedParams("tree_method").asInstanceOf[String]) } + // Back-compat: the legacy tree_method "gpu_hist" implies CUDA even when `device` is unset. + val device: Option[String] = + if (treeMethod.contains("gpu_hist")) Some("cuda") + else overridedParams.get("device").map(_.toString) if (overridedParams.contains("train_test_ratio")) { logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + " pass a training and multiple evaluation datasets by
passing 'eval_sets' and " + @@ -228,7 +232,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s inputParams, xgbExecEarlyStoppingParams, cacheTrainingSet, - treeMethod, + device, isLocal, featureNames, featureTypes @@ -318,7 +322,7 @@ object XGBoost extends Serializable { val externalCheckpointParams = xgbExecutionParam.checkpointParam var params = xgbExecutionParam.toMap - if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) { + if (xgbExecutionParam.device.exists(m => (m == "cuda" || m == "gpu"))) { val gpuId = if (xgbExecutionParam.isLocal) { // For local mode, force gpu id to primary device 0 @@ -328,6 +332,7 @@ object XGBoost extends Serializable { logger.info("Leveraging gpu device " + gpuId + " to train") params = params + ("device" -> s"cuda:$gpuId") } + val booster = if (makeCheckpoint) { SXGBoost.trainAndSaveCheckpoint( watches.toMap("train"), params, numRounds, diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index fd4633a0d..ec8766e40 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -93,6 +93,8 @@ class XGBoostClassifier ( def setTreeMethod(value: String): this.type = set(treeMethod, value) + def setDevice(value: String): this.type = set(device, value) + def setGrowPolicy(value: String): this.type = set(growPolicy, value) def setMaxBins(value: Int): this.type = set(maxBins, value) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index 99dbdc580..986e04c6b 100644 --- 
a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -95,6 +95,8 @@ class XGBoostRegressor ( def setTreeMethod(value: String): this.type = set(treeMethod, value) + def setDevice(value: String): this.type = set(device, value) + def setGrowPolicy(value: String): this.type = set(growPolicy, value) def setMaxBins(value: Int): this.type = set(maxBins, value) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala index 21a77341c..61efc2865 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala @@ -154,6 +154,14 @@ private[spark] trait BoosterParams extends Params { (value: String) => BoosterParams.supportedTreeMethods.contains(value)) final def getTreeMethod: String = $(treeMethod) + /** + * The device for running XGBoost algorithms, options: cpu, cuda + */ + final val device = new Param[String]( + this, "device", "The device for running XGBoost algorithms, options: cpu, cuda" + ) + + final def getDevice: String = $(device) /** * growth policy for fast histogram algorithm diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index 3f387de9b..b85f4dc8b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -284,7 +284,7 @@ private[spark] trait ParamMapFuncs extends Params { 
(paramName == "updater" && paramValue != "grow_histmaker,prune" && paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) { throw new IllegalArgumentException(s"you specified $paramName as $paramValue," + - s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker,prune or" + + s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker or" + s" grow_quantile_histmaker or grow_gpu_hist as the updater type") } val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala index 1290465ea..9b53c7642 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala @@ -469,7 +469,6 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS .setFeatureTypes(featureTypes) val model = xgb.fit(trainingDF) val modelStr = new String(model._booster.toByteArray("json")) - System.out.println(modelStr) val jsonModel = parseJson(modelStr) implicit val formats: Formats = DefaultFormats val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]] diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index d53c003a4..70966a38f 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -143,7 +143,6 @@ public class BoosterImplTest { booster.saveModel(temp.getAbsolutePath()); String modelString = new String(booster.toByteArray("json")); - 
System.out.println(modelString); Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath()); assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));