[jvm-packages] Add the new device parameter. (#9385)
This commit is contained in:
parent
2caceb157d
commit
f4fb2be101
@ -121,7 +121,7 @@ To train a XGBoost model for classification, we need to claim a XGBoostClassifie
|
|||||||
"objective" -> "multi:softprob",
|
"objective" -> "multi:softprob",
|
||||||
"num_class" -> 3,
|
"num_class" -> 3,
|
||||||
"num_round" -> 100,
|
"num_round" -> 100,
|
||||||
"tree_method" -> "gpu_hist",
|
"device" -> "cuda",
|
||||||
"num_workers" -> 1)
|
"num_workers" -> 1)
|
||||||
|
|
||||||
val featuresNames = schema.fieldNames.filter(name => name != labelName)
|
val featuresNames = schema.fieldNames.filter(name => name != labelName)
|
||||||
@ -130,15 +130,14 @@ To train a XGBoost model for classification, we need to claim a XGBoostClassifie
|
|||||||
.setFeaturesCol(featuresNames)
|
.setFeaturesCol(featuresNames)
|
||||||
.setLabelCol(labelName)
|
.setLabelCol(labelName)
|
||||||
|
|
||||||
The available parameters for training a XGBoost model can be found in :doc:`here </parameter>`.
|
The ``device`` parameter is for informing XGBoost that CUDA devices should be used instead of CPU. Unlike the single-node mode, GPUs are managed by spark instead of by XGBoost. Therefore, explicitly specified device ordinal like ``cuda:1`` is not support.
|
||||||
Similar to the XGBoost4J-Spark package, in addition to the default set of parameters,
|
|
||||||
XGBoost4J-Spark-GPU also supports the camel-case variant of these parameters to be
|
The available parameters for training a XGBoost model can be found in :doc:`here </parameter>`. Similar to the XGBoost4J-Spark package, in addition to the default set of parameters, XGBoost4J-Spark-GPU also supports the camel-case variant of these parameters to be consistent with Spark's MLlib naming convention.
|
||||||
consistent with Spark's MLlib naming convention.
|
|
||||||
|
|
||||||
Specifically, each parameter in :doc:`this page </parameter>` has its equivalent form in
|
Specifically, each parameter in :doc:`this page </parameter>` has its equivalent form in
|
||||||
XGBoost4J-Spark-GPU with camel case. For example, to set ``max_depth`` for each tree, you can pass
|
XGBoost4J-Spark-GPU with camel case. For example, to set ``max_depth`` for each tree, you
|
||||||
parameter just like what we did in the above code snippet (as ``max_depth`` wrapped in a Map), or
|
can pass parameter just like what we did in the above code snippet (as ``max_depth``
|
||||||
you can do it through setters in XGBoostClassifer:
|
wrapped in a Map), or you can do it through setters in XGBoostClassifer:
|
||||||
|
|
||||||
.. code-block:: scala
|
.. code-block:: scala
|
||||||
|
|
||||||
|
|||||||
@ -40,20 +40,20 @@ object SparkMLlibPipeline {
|
|||||||
val nativeModelPath = args(1)
|
val nativeModelPath = args(1)
|
||||||
val pipelineModelPath = args(2)
|
val pipelineModelPath = args(2)
|
||||||
|
|
||||||
val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
|
val (device, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
|
||||||
("gpu_hist", 1)
|
("cuda", 1)
|
||||||
} else ("auto", 2)
|
} else ("cpu", 2)
|
||||||
|
|
||||||
val spark = SparkSession
|
val spark = SparkSession
|
||||||
.builder()
|
.builder()
|
||||||
.appName("XGBoost4J-Spark Pipeline Example")
|
.appName("XGBoost4J-Spark Pipeline Example")
|
||||||
.getOrCreate()
|
.getOrCreate()
|
||||||
|
|
||||||
run(spark, inputPath, nativeModelPath, pipelineModelPath, treeMethod, numWorkers)
|
run(spark, inputPath, nativeModelPath, pipelineModelPath, device, numWorkers)
|
||||||
.show(false)
|
.show(false)
|
||||||
}
|
}
|
||||||
private[spark] def run(spark: SparkSession, inputPath: String, nativeModelPath: String,
|
private[spark] def run(spark: SparkSession, inputPath: String, nativeModelPath: String,
|
||||||
pipelineModelPath: String, treeMethod: String,
|
pipelineModelPath: String, device: String,
|
||||||
numWorkers: Int): DataFrame = {
|
numWorkers: Int): DataFrame = {
|
||||||
|
|
||||||
// Load dataset
|
// Load dataset
|
||||||
@ -82,13 +82,14 @@ object SparkMLlibPipeline {
|
|||||||
.setOutputCol("classIndex")
|
.setOutputCol("classIndex")
|
||||||
.fit(training)
|
.fit(training)
|
||||||
val booster = new XGBoostClassifier(
|
val booster = new XGBoostClassifier(
|
||||||
Map("eta" -> 0.1f,
|
Map(
|
||||||
|
"eta" -> 0.1f,
|
||||||
"max_depth" -> 2,
|
"max_depth" -> 2,
|
||||||
"objective" -> "multi:softprob",
|
"objective" -> "multi:softprob",
|
||||||
"num_class" -> 3,
|
"num_class" -> 3,
|
||||||
"num_round" -> 100,
|
"num_round" -> 100,
|
||||||
"num_workers" -> numWorkers,
|
"num_workers" -> numWorkers,
|
||||||
"tree_method" -> treeMethod
|
"device" -> device
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
booster.setFeaturesCol("features")
|
booster.setFeaturesCol("features")
|
||||||
|
|||||||
@ -31,18 +31,18 @@ object SparkTraining {
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
|
val (device, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
|
||||||
("gpu_hist", 1)
|
("cuda", 1)
|
||||||
} else ("auto", 2)
|
} else ("cpu", 2)
|
||||||
|
|
||||||
val spark = SparkSession.builder().getOrCreate()
|
val spark = SparkSession.builder().getOrCreate()
|
||||||
val inputPath = args(0)
|
val inputPath = args(0)
|
||||||
val results: DataFrame = run(spark, inputPath, treeMethod, numWorkers)
|
val results: DataFrame = run(spark, inputPath, device, numWorkers)
|
||||||
results.show()
|
results.show()
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark] def run(spark: SparkSession, inputPath: String,
|
private[spark] def run(spark: SparkSession, inputPath: String,
|
||||||
treeMethod: String, numWorkers: Int): DataFrame = {
|
device: String, numWorkers: Int): DataFrame = {
|
||||||
val schema = new StructType(Array(
|
val schema = new StructType(Array(
|
||||||
StructField("sepal length", DoubleType, true),
|
StructField("sepal length", DoubleType, true),
|
||||||
StructField("sepal width", DoubleType, true),
|
StructField("sepal width", DoubleType, true),
|
||||||
@ -80,7 +80,7 @@ private[spark] def run(spark: SparkSession, inputPath: String,
|
|||||||
"num_class" -> 3,
|
"num_class" -> 3,
|
||||||
"num_round" -> 100,
|
"num_round" -> 100,
|
||||||
"num_workers" -> numWorkers,
|
"num_workers" -> numWorkers,
|
||||||
"tree_method" -> treeMethod,
|
"device" -> device,
|
||||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
||||||
val xgbClassifier = new XGBoostClassifier(xgbParam).
|
val xgbClassifier = new XGBoostClassifier(xgbParam).
|
||||||
setFeaturesCol("features").
|
setFeaturesCol("features").
|
||||||
|
|||||||
@ -104,7 +104,7 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
|||||||
|
|
||||||
test("Smoke test for SparkMLlibPipeline example") {
|
test("Smoke test for SparkMLlibPipeline example") {
|
||||||
SparkMLlibPipeline.run(spark, pathToTestDataset.toString, "target/native-model",
|
SparkMLlibPipeline.run(spark, pathToTestDataset.toString, "target/native-model",
|
||||||
"target/pipeline-model", "auto", 2)
|
"target/pipeline-model", "cpu", 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Smoke test for SparkTraining example") {
|
test("Smoke test for SparkTraining example") {
|
||||||
@ -118,6 +118,6 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
|||||||
.config("spark.task.cpus", 1)
|
.config("spark.task.cpus", 1)
|
||||||
.getOrCreate()
|
.getOrCreate()
|
||||||
|
|
||||||
SparkTraining.run(spark, pathToTestDataset.toString, "auto", 2)
|
SparkTraining.run(spark, pathToTestDataset.toString, "cpu", 2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -77,7 +77,8 @@ public class BoosterTest {
|
|||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
put("num_round", round);
|
put("num_round", round);
|
||||||
put("num_workers", 1);
|
put("num_workers", 1);
|
||||||
put("tree_method", "gpu_hist");
|
put("tree_method", "hist");
|
||||||
|
put("device", "cuda");
|
||||||
put("max_bin", maxBin);
|
put("max_bin", maxBin);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -137,8 +137,12 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
|
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
|
||||||
estimator match {
|
estimator match {
|
||||||
case est: XGBoostEstimatorCommon =>
|
case est: XGBoostEstimatorCommon =>
|
||||||
require(est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
|
require(
|
||||||
s"GPU train requires tree_method set to gpu_hist")
|
est.isDefined(est.device) &&
|
||||||
|
(est.getDevice.equals("cuda") || est.getDevice.equals("gpu")) ||
|
||||||
|
est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
|
||||||
|
s"GPU train requires `device` set to `cuda` or `gpu`."
|
||||||
|
)
|
||||||
val groupName = estimator match {
|
val groupName = estimator match {
|
||||||
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
|
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
|
||||||
regressor.getGroupCol } else ""
|
regressor.getGroupCol } else ""
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021-2022 by Contributors
|
Copyright (c) 2021-2023 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -50,9 +50,12 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
|||||||
withGpuSparkSession() { spark =>
|
withGpuSparkSession() { spark =>
|
||||||
import spark.implicits._
|
import spark.implicits._
|
||||||
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
val xgbParam = Map(
|
||||||
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
"eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
||||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
|
||||||
|
"tree_method" -> "hist", "device" -> "cuda",
|
||||||
|
"features_cols" -> featureNames, "label_col" -> labelName
|
||||||
|
)
|
||||||
new XGBoostClassifier(xgbParam)
|
new XGBoostClassifier(xgbParam)
|
||||||
.fit(trainingDf)
|
.fit(trainingDf)
|
||||||
}
|
}
|
||||||
@ -65,8 +68,11 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
|||||||
|
|
||||||
trainingDf = trainingDf.select(labelName, "f2", weightName, "f3", baseMarginName, "f1")
|
trainingDf = trainingDf.select(labelName, "f2", weightName, "f3", baseMarginName, "f1")
|
||||||
|
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
val xgbParam = Map(
|
||||||
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
|
"eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
||||||
|
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
|
||||||
|
"tree_method" -> "hist", "device" -> "cuda"
|
||||||
|
)
|
||||||
new XGBoostClassifier(xgbParam)
|
new XGBoostClassifier(xgbParam)
|
||||||
.setFeaturesCol(featureNames)
|
.setFeaturesCol(featureNames)
|
||||||
.setLabelCol(labelName)
|
.setLabelCol(labelName)
|
||||||
@ -127,7 +133,7 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Throw exception when tree method is not set to gpu_hist") {
|
test("Throw exception when device is not set to cuda") {
|
||||||
withGpuSparkSession() { spark =>
|
withGpuSparkSession() { spark =>
|
||||||
import spark.implicits._
|
import spark.implicits._
|
||||||
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
||||||
@ -139,12 +145,11 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
|||||||
.setLabelCol(labelName)
|
.setLabelCol(labelName)
|
||||||
.fit(trainingDf)
|
.fit(trainingDf)
|
||||||
}
|
}
|
||||||
assert(thrown.getMessage.contains("GPU train requires tree_method set to gpu_hist"))
|
assert(thrown.getMessage.contains("GPU train requires `device` set to `cuda`"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Train with eval") {
|
test("Train with eval") {
|
||||||
|
|
||||||
withGpuSparkSession() { spark =>
|
withGpuSparkSession() { spark =>
|
||||||
import spark.implicits._
|
import spark.implicits._
|
||||||
val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*)
|
val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*)
|
||||||
@ -184,4 +189,24 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("device ordinal should not be specified") {
|
||||||
|
withGpuSparkSession() { spark =>
|
||||||
|
import spark.implicits._
|
||||||
|
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
||||||
|
val params = Map(
|
||||||
|
"objective" -> "multi:softprob",
|
||||||
|
"num_class" -> 3,
|
||||||
|
"num_round" -> 5,
|
||||||
|
"num_workers" -> 1
|
||||||
|
)
|
||||||
|
val thrown = intercept[IllegalArgumentException] {
|
||||||
|
new XGBoostClassifier(params)
|
||||||
|
.setFeaturesCol(featureNames)
|
||||||
|
.setLabelCol(labelName)
|
||||||
|
.setDevice("cuda:1")
|
||||||
|
.fit(trainingDf)
|
||||||
|
}
|
||||||
|
assert(thrown.getMessage.contains("`cuda` or `gpu`"))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021-2022 by Contributors
|
Copyright (c) 2021-2023 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -40,7 +40,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
|||||||
test("The transform result should be same for several runs on same model") {
|
test("The transform result should be same for several runs on same model") {
|
||||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
withGpuSparkSession(enableCsvConf()) { spark =>
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
||||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
|
||||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
"features_cols" -> featureNames, "label_col" -> labelName)
|
||||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||||
@ -54,10 +54,30 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Tree method gpu_hist still works") {
|
||||||
|
withGpuSparkSession(enableCsvConf()) { spark =>
|
||||||
|
val params = Map(
|
||||||
|
"tree_method" -> "gpu_hist",
|
||||||
|
"features_cols" -> featureNames,
|
||||||
|
"label_col" -> labelName,
|
||||||
|
"num_round" -> 10,
|
||||||
|
"num_workers" -> 1
|
||||||
|
)
|
||||||
|
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||||
|
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||||
|
// Get a model
|
||||||
|
val model = new XGBoostRegressor(params).fit(originalDf)
|
||||||
|
val left = model.transform(testDf).collect()
|
||||||
|
val right = model.transform(testDf).collect()
|
||||||
|
// The left should be same with right
|
||||||
|
assert(compareResults(true, 0.000001, left, right))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
test("use weight") {
|
test("use weight") {
|
||||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
withGpuSparkSession(enableCsvConf()) { spark =>
|
||||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
||||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
|
||||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
"features_cols" -> featureNames, "label_col" -> labelName)
|
||||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||||
@ -88,7 +108,8 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
|||||||
val classifier = new XGBoostRegressor(xgbParam)
|
val classifier = new XGBoostRegressor(xgbParam)
|
||||||
.setFeaturesCol(featureNames)
|
.setFeaturesCol(featureNames)
|
||||||
.setLabelCol(labelName)
|
.setLabelCol(labelName)
|
||||||
.setTreeMethod("gpu_hist")
|
.setTreeMethod("hist")
|
||||||
|
.setDevice("cuda")
|
||||||
(classifier.fit(rawInput), testDf)
|
(classifier.fit(rawInput), testDf)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -175,7 +196,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
|||||||
val classifier = new XGBoostRegressor(xgbParam)
|
val classifier = new XGBoostRegressor(xgbParam)
|
||||||
.setFeaturesCol(featureNames)
|
.setFeaturesCol(featureNames)
|
||||||
.setLabelCol(labelName)
|
.setLabelCol(labelName)
|
||||||
.setTreeMethod("gpu_hist")
|
.setDevice("cuda")
|
||||||
classifier.fit(rawInput)
|
classifier.fit(rawInput)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -234,5 +255,4 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
|||||||
assert(testDf.count() === ret.length)
|
assert(testDf.count() === ret.length)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -73,7 +73,7 @@ private[scala] case class XGBoostExecutionParams(
|
|||||||
xgbInputParams: XGBoostExecutionInputParams,
|
xgbInputParams: XGBoostExecutionInputParams,
|
||||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||||
cacheTrainingSet: Boolean,
|
cacheTrainingSet: Boolean,
|
||||||
treeMethod: Option[String],
|
device: Option[String],
|
||||||
isLocal: Boolean,
|
isLocal: Boolean,
|
||||||
featureNames: Option[Array[String]],
|
featureNames: Option[Array[String]],
|
||||||
featureTypes: Option[Array[String]]) {
|
featureTypes: Option[Array[String]]) {
|
||||||
@ -180,6 +180,10 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
||||||
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
||||||
}
|
}
|
||||||
|
val device: Option[String] = overridedParams.get("device") match {
|
||||||
|
case None => None
|
||||||
|
case Some(dev: String) => if (treeMethod == "gpu_hist") Some("cuda") else Some(dev)
|
||||||
|
}
|
||||||
if (overridedParams.contains("train_test_ratio")) {
|
if (overridedParams.contains("train_test_ratio")) {
|
||||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||||
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
||||||
@ -228,7 +232,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
inputParams,
|
inputParams,
|
||||||
xgbExecEarlyStoppingParams,
|
xgbExecEarlyStoppingParams,
|
||||||
cacheTrainingSet,
|
cacheTrainingSet,
|
||||||
treeMethod,
|
device,
|
||||||
isLocal,
|
isLocal,
|
||||||
featureNames,
|
featureNames,
|
||||||
featureTypes
|
featureTypes
|
||||||
@ -318,7 +322,7 @@ object XGBoost extends Serializable {
|
|||||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||||
|
|
||||||
var params = xgbExecutionParam.toMap
|
var params = xgbExecutionParam.toMap
|
||||||
if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) {
|
if (xgbExecutionParam.device.exists(m => (m == "cuda" || m == "gpu"))) {
|
||||||
val gpuId = if (xgbExecutionParam.isLocal) {
|
val gpuId = if (xgbExecutionParam.isLocal) {
|
||||||
// For local mode, force gpu id to primary device
|
// For local mode, force gpu id to primary device
|
||||||
0
|
0
|
||||||
@ -328,6 +332,7 @@ object XGBoost extends Serializable {
|
|||||||
logger.info("Leveraging gpu device " + gpuId + " to train")
|
logger.info("Leveraging gpu device " + gpuId + " to train")
|
||||||
params = params + ("device" -> s"cuda:$gpuId")
|
params = params + ("device" -> s"cuda:$gpuId")
|
||||||
}
|
}
|
||||||
|
|
||||||
val booster = if (makeCheckpoint) {
|
val booster = if (makeCheckpoint) {
|
||||||
SXGBoost.trainAndSaveCheckpoint(
|
SXGBoost.trainAndSaveCheckpoint(
|
||||||
watches.toMap("train"), params, numRounds,
|
watches.toMap("train"), params, numRounds,
|
||||||
|
|||||||
@ -93,6 +93,8 @@ class XGBoostClassifier (
|
|||||||
|
|
||||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
||||||
|
|
||||||
|
def setDevice(value: String): this.type = set(device, value)
|
||||||
|
|
||||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
||||||
|
|
||||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
||||||
|
|||||||
@ -95,6 +95,8 @@ class XGBoostRegressor (
|
|||||||
|
|
||||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
||||||
|
|
||||||
|
def setDevice(value: String): this.type = set(device, value)
|
||||||
|
|
||||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
||||||
|
|
||||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
||||||
|
|||||||
@ -154,6 +154,14 @@ private[spark] trait BoosterParams extends Params {
|
|||||||
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
||||||
|
|
||||||
final def getTreeMethod: String = $(treeMethod)
|
final def getTreeMethod: String = $(treeMethod)
|
||||||
|
/**
|
||||||
|
* The device for running XGBoost algorithms, options: cpu, cuda
|
||||||
|
*/
|
||||||
|
final val device = new Param[String](
|
||||||
|
this, "device", "The device for running XGBoost algorithms, options: cpu, cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
final def getDevice: String = $(device)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* growth policy for fast histogram algorithm
|
* growth policy for fast histogram algorithm
|
||||||
|
|||||||
@ -284,7 +284,7 @@ private[spark] trait ParamMapFuncs extends Params {
|
|||||||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
||||||
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
|
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
|
||||||
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
|
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
|
||||||
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker,prune or" +
|
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker or" +
|
||||||
s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
|
s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
|
||||||
}
|
}
|
||||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||||
|
|||||||
@ -469,7 +469,6 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
|
|||||||
.setFeatureTypes(featureTypes)
|
.setFeatureTypes(featureTypes)
|
||||||
val model = xgb.fit(trainingDF)
|
val model = xgb.fit(trainingDF)
|
||||||
val modelStr = new String(model._booster.toByteArray("json"))
|
val modelStr = new String(model._booster.toByteArray("json"))
|
||||||
System.out.println(modelStr)
|
|
||||||
val jsonModel = parseJson(modelStr)
|
val jsonModel = parseJson(modelStr)
|
||||||
implicit val formats: Formats = DefaultFormats
|
implicit val formats: Formats = DefaultFormats
|
||||||
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
|
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
|
||||||
|
|||||||
@ -143,7 +143,6 @@ public class BoosterImplTest {
|
|||||||
booster.saveModel(temp.getAbsolutePath());
|
booster.saveModel(temp.getAbsolutePath());
|
||||||
|
|
||||||
String modelString = new String(booster.toByteArray("json"));
|
String modelString = new String(booster.toByteArray("json"));
|
||||||
System.out.println(modelString);
|
|
||||||
|
|
||||||
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
|
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
|
||||||
assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));
|
assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user