[jvm-packages] Add the new device parameter. (#9385)
This commit is contained in:
parent
2caceb157d
commit
f4fb2be101
@ -121,7 +121,7 @@ To train a XGBoost model for classification, we need to claim a XGBoostClassifie
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"tree_method" -> "gpu_hist",
|
||||
"device" -> "cuda",
|
||||
"num_workers" -> 1)
|
||||
|
||||
val featuresNames = schema.fieldNames.filter(name => name != labelName)
|
||||
@ -130,15 +130,14 @@ To train a XGBoost model for classification, we need to claim a XGBoostClassifie
|
||||
.setFeaturesCol(featuresNames)
|
||||
.setLabelCol(labelName)
|
||||
|
||||
The available parameters for training a XGBoost model can be found in :doc:`here </parameter>`.
|
||||
Similar to the XGBoost4J-Spark package, in addition to the default set of parameters,
|
||||
XGBoost4J-Spark-GPU also supports the camel-case variant of these parameters to be
|
||||
consistent with Spark's MLlib naming convention.
|
||||
The ``device`` parameter is for informing XGBoost that CUDA devices should be used instead of CPU. Unlike the single-node mode, GPUs are managed by spark instead of by XGBoost. Therefore, explicitly specified device ordinal like ``cuda:1`` is not support.
|
||||
|
||||
The available parameters for training a XGBoost model can be found in :doc:`here </parameter>`. Similar to the XGBoost4J-Spark package, in addition to the default set of parameters, XGBoost4J-Spark-GPU also supports the camel-case variant of these parameters to be consistent with Spark's MLlib naming convention.
|
||||
|
||||
Specifically, each parameter in :doc:`this page </parameter>` has its equivalent form in
|
||||
XGBoost4J-Spark-GPU with camel case. For example, to set ``max_depth`` for each tree, you can pass
|
||||
parameter just like what we did in the above code snippet (as ``max_depth`` wrapped in a Map), or
|
||||
you can do it through setters in XGBoostClassifer:
|
||||
XGBoost4J-Spark-GPU with camel case. For example, to set ``max_depth`` for each tree, you
|
||||
can pass parameter just like what we did in the above code snippet (as ``max_depth``
|
||||
wrapped in a Map), or you can do it through setters in XGBoostClassifer:
|
||||
|
||||
.. code-block:: scala
|
||||
|
||||
|
||||
@ -40,20 +40,20 @@ object SparkMLlibPipeline {
|
||||
val nativeModelPath = args(1)
|
||||
val pipelineModelPath = args(2)
|
||||
|
||||
val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
|
||||
("gpu_hist", 1)
|
||||
} else ("auto", 2)
|
||||
val (device, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
|
||||
("cuda", 1)
|
||||
} else ("cpu", 2)
|
||||
|
||||
val spark = SparkSession
|
||||
.builder()
|
||||
.appName("XGBoost4J-Spark Pipeline Example")
|
||||
.getOrCreate()
|
||||
|
||||
run(spark, inputPath, nativeModelPath, pipelineModelPath, treeMethod, numWorkers)
|
||||
run(spark, inputPath, nativeModelPath, pipelineModelPath, device, numWorkers)
|
||||
.show(false)
|
||||
}
|
||||
private[spark] def run(spark: SparkSession, inputPath: String, nativeModelPath: String,
|
||||
pipelineModelPath: String, treeMethod: String,
|
||||
pipelineModelPath: String, device: String,
|
||||
numWorkers: Int): DataFrame = {
|
||||
|
||||
// Load dataset
|
||||
@ -82,13 +82,14 @@ object SparkMLlibPipeline {
|
||||
.setOutputCol("classIndex")
|
||||
.fit(training)
|
||||
val booster = new XGBoostClassifier(
|
||||
Map("eta" -> 0.1f,
|
||||
Map(
|
||||
"eta" -> 0.1f,
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> numWorkers,
|
||||
"tree_method" -> treeMethod
|
||||
"device" -> device
|
||||
)
|
||||
)
|
||||
booster.setFeaturesCol("features")
|
||||
|
||||
@ -31,18 +31,18 @@ object SparkTraining {
|
||||
sys.exit(1)
|
||||
}
|
||||
|
||||
val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
|
||||
("gpu_hist", 1)
|
||||
} else ("auto", 2)
|
||||
val (device, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
|
||||
("cuda", 1)
|
||||
} else ("cpu", 2)
|
||||
|
||||
val spark = SparkSession.builder().getOrCreate()
|
||||
val inputPath = args(0)
|
||||
val results: DataFrame = run(spark, inputPath, treeMethod, numWorkers)
|
||||
val results: DataFrame = run(spark, inputPath, device, numWorkers)
|
||||
results.show()
|
||||
}
|
||||
|
||||
private[spark] def run(spark: SparkSession, inputPath: String,
|
||||
treeMethod: String, numWorkers: Int): DataFrame = {
|
||||
device: String, numWorkers: Int): DataFrame = {
|
||||
val schema = new StructType(Array(
|
||||
StructField("sepal length", DoubleType, true),
|
||||
StructField("sepal width", DoubleType, true),
|
||||
@ -80,7 +80,7 @@ private[spark] def run(spark: SparkSession, inputPath: String,
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> numWorkers,
|
||||
"tree_method" -> treeMethod,
|
||||
"device" -> device,
|
||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgbClassifier = new XGBoostClassifier(xgbParam).
|
||||
setFeaturesCol("features").
|
||||
|
||||
@ -104,7 +104,7 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
||||
|
||||
test("Smoke test for SparkMLlibPipeline example") {
|
||||
SparkMLlibPipeline.run(spark, pathToTestDataset.toString, "target/native-model",
|
||||
"target/pipeline-model", "auto", 2)
|
||||
"target/pipeline-model", "cpu", 2)
|
||||
}
|
||||
|
||||
test("Smoke test for SparkTraining example") {
|
||||
@ -118,6 +118,6 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
||||
.config("spark.task.cpus", 1)
|
||||
.getOrCreate()
|
||||
|
||||
SparkTraining.run(spark, pathToTestDataset.toString, "auto", 2)
|
||||
SparkTraining.run(spark, pathToTestDataset.toString, "cpu", 2)
|
||||
}
|
||||
}
|
||||
|
||||
@ -77,7 +77,8 @@ public class BoosterTest {
|
||||
put("objective", "binary:logistic");
|
||||
put("num_round", round);
|
||||
put("num_workers", 1);
|
||||
put("tree_method", "gpu_hist");
|
||||
put("tree_method", "hist");
|
||||
put("device", "cuda");
|
||||
put("max_bin", maxBin);
|
||||
}
|
||||
};
|
||||
|
||||
@ -137,8 +137,12 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
|
||||
estimator match {
|
||||
case est: XGBoostEstimatorCommon =>
|
||||
require(est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
|
||||
s"GPU train requires tree_method set to gpu_hist")
|
||||
require(
|
||||
est.isDefined(est.device) &&
|
||||
(est.getDevice.equals("cuda") || est.getDevice.equals("gpu")) ||
|
||||
est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
|
||||
s"GPU train requires `device` set to `cuda` or `gpu`."
|
||||
)
|
||||
val groupName = estimator match {
|
||||
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
|
||||
regressor.getGroupCol } else ""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
Copyright (c) 2021-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -50,9 +50,12 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
||||
withGpuSparkSession() { spark =>
|
||||
import spark.implicits._
|
||||
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
||||
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
||||
val xgbParam = Map(
|
||||
"eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
||||
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
|
||||
"tree_method" -> "hist", "device" -> "cuda",
|
||||
"features_cols" -> featureNames, "label_col" -> labelName
|
||||
)
|
||||
new XGBoostClassifier(xgbParam)
|
||||
.fit(trainingDf)
|
||||
}
|
||||
@ -65,8 +68,11 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
||||
|
||||
trainingDf = trainingDf.select(labelName, "f2", weightName, "f3", baseMarginName, "f1")
|
||||
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
||||
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
|
||||
val xgbParam = Map(
|
||||
"eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
||||
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
|
||||
"tree_method" -> "hist", "device" -> "cuda"
|
||||
)
|
||||
new XGBoostClassifier(xgbParam)
|
||||
.setFeaturesCol(featureNames)
|
||||
.setLabelCol(labelName)
|
||||
@ -127,7 +133,7 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
||||
}
|
||||
}
|
||||
|
||||
test("Throw exception when tree method is not set to gpu_hist") {
|
||||
test("Throw exception when device is not set to cuda") {
|
||||
withGpuSparkSession() { spark =>
|
||||
import spark.implicits._
|
||||
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
||||
@ -139,12 +145,11 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
||||
.setLabelCol(labelName)
|
||||
.fit(trainingDf)
|
||||
}
|
||||
assert(thrown.getMessage.contains("GPU train requires tree_method set to gpu_hist"))
|
||||
assert(thrown.getMessage.contains("GPU train requires `device` set to `cuda`"))
|
||||
}
|
||||
}
|
||||
|
||||
test("Train with eval") {
|
||||
|
||||
withGpuSparkSession() { spark =>
|
||||
import spark.implicits._
|
||||
val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*)
|
||||
@ -184,4 +189,24 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
||||
}
|
||||
}
|
||||
|
||||
test("device ordinal should not be specified") {
|
||||
withGpuSparkSession() { spark =>
|
||||
import spark.implicits._
|
||||
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
||||
val params = Map(
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 5,
|
||||
"num_workers" -> 1
|
||||
)
|
||||
val thrown = intercept[IllegalArgumentException] {
|
||||
new XGBoostClassifier(params)
|
||||
.setFeaturesCol(featureNames)
|
||||
.setLabelCol(labelName)
|
||||
.setDevice("cuda:1")
|
||||
.fit(trainingDf)
|
||||
}
|
||||
assert(thrown.getMessage.contains("`cuda` or `gpu`"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
Copyright (c) 2021-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -40,7 +40,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
||||
test("The transform result should be same for several runs on same model") {
|
||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
|
||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
@ -54,10 +54,30 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
||||
}
|
||||
}
|
||||
|
||||
test("Tree method gpu_hist still works") {
|
||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
||||
val params = Map(
|
||||
"tree_method" -> "gpu_hist",
|
||||
"features_cols" -> featureNames,
|
||||
"label_col" -> labelName,
|
||||
"num_round" -> 10,
|
||||
"num_workers" -> 1
|
||||
)
|
||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
// Get a model
|
||||
val model = new XGBoostRegressor(params).fit(originalDf)
|
||||
val left = model.transform(testDf).collect()
|
||||
val right = model.transform(testDf).collect()
|
||||
// The left should be same with right
|
||||
assert(compareResults(true, 0.000001, left, right))
|
||||
}
|
||||
}
|
||||
|
||||
test("use weight") {
|
||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
|
||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
@ -88,7 +108,8 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
||||
val classifier = new XGBoostRegressor(xgbParam)
|
||||
.setFeaturesCol(featureNames)
|
||||
.setLabelCol(labelName)
|
||||
.setTreeMethod("gpu_hist")
|
||||
.setTreeMethod("hist")
|
||||
.setDevice("cuda")
|
||||
(classifier.fit(rawInput), testDf)
|
||||
}
|
||||
|
||||
@ -175,7 +196,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
||||
val classifier = new XGBoostRegressor(xgbParam)
|
||||
.setFeaturesCol(featureNames)
|
||||
.setLabelCol(labelName)
|
||||
.setTreeMethod("gpu_hist")
|
||||
.setDevice("cuda")
|
||||
classifier.fit(rawInput)
|
||||
}
|
||||
|
||||
@ -234,5 +255,4 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
||||
assert(testDf.count() === ret.length)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -73,7 +73,7 @@ private[scala] case class XGBoostExecutionParams(
|
||||
xgbInputParams: XGBoostExecutionInputParams,
|
||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||
cacheTrainingSet: Boolean,
|
||||
treeMethod: Option[String],
|
||||
device: Option[String],
|
||||
isLocal: Boolean,
|
||||
featureNames: Option[Array[String]],
|
||||
featureTypes: Option[Array[String]]) {
|
||||
@ -180,6 +180,10 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
||||
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
||||
}
|
||||
val device: Option[String] = overridedParams.get("device") match {
|
||||
case None => None
|
||||
case Some(dev: String) => if (treeMethod == "gpu_hist") Some("cuda") else Some(dev)
|
||||
}
|
||||
if (overridedParams.contains("train_test_ratio")) {
|
||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
||||
@ -228,7 +232,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
inputParams,
|
||||
xgbExecEarlyStoppingParams,
|
||||
cacheTrainingSet,
|
||||
treeMethod,
|
||||
device,
|
||||
isLocal,
|
||||
featureNames,
|
||||
featureTypes
|
||||
@ -318,7 +322,7 @@ object XGBoost extends Serializable {
|
||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||
|
||||
var params = xgbExecutionParam.toMap
|
||||
if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) {
|
||||
if (xgbExecutionParam.device.exists(m => (m == "cuda" || m == "gpu"))) {
|
||||
val gpuId = if (xgbExecutionParam.isLocal) {
|
||||
// For local mode, force gpu id to primary device
|
||||
0
|
||||
@ -328,6 +332,7 @@ object XGBoost extends Serializable {
|
||||
logger.info("Leveraging gpu device " + gpuId + " to train")
|
||||
params = params + ("device" -> s"cuda:$gpuId")
|
||||
}
|
||||
|
||||
val booster = if (makeCheckpoint) {
|
||||
SXGBoost.trainAndSaveCheckpoint(
|
||||
watches.toMap("train"), params, numRounds,
|
||||
|
||||
@ -93,6 +93,8 @@ class XGBoostClassifier (
|
||||
|
||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
||||
|
||||
def setDevice(value: String): this.type = set(device, value)
|
||||
|
||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
||||
|
||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
||||
|
||||
@ -95,6 +95,8 @@ class XGBoostRegressor (
|
||||
|
||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
||||
|
||||
def setDevice(value: String): this.type = set(device, value)
|
||||
|
||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
||||
|
||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
||||
|
||||
@ -154,6 +154,14 @@ private[spark] trait BoosterParams extends Params {
|
||||
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
||||
|
||||
final def getTreeMethod: String = $(treeMethod)
|
||||
/**
|
||||
* The device for running XGBoost algorithms, options: cpu, cuda
|
||||
*/
|
||||
final val device = new Param[String](
|
||||
this, "device", "The device for running XGBoost algorithms, options: cpu, cuda"
|
||||
)
|
||||
|
||||
final def getDevice: String = $(device)
|
||||
|
||||
/**
|
||||
* growth policy for fast histogram algorithm
|
||||
|
||||
@ -284,7 +284,7 @@ private[spark] trait ParamMapFuncs extends Params {
|
||||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
||||
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
|
||||
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
|
||||
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker,prune or" +
|
||||
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker or" +
|
||||
s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
|
||||
}
|
||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||
|
||||
@ -469,7 +469,6 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
|
||||
.setFeatureTypes(featureTypes)
|
||||
val model = xgb.fit(trainingDF)
|
||||
val modelStr = new String(model._booster.toByteArray("json"))
|
||||
System.out.println(modelStr)
|
||||
val jsonModel = parseJson(modelStr)
|
||||
implicit val formats: Formats = DefaultFormats
|
||||
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
|
||||
|
||||
@ -143,7 +143,6 @@ public class BoosterImplTest {
|
||||
booster.saveModel(temp.getAbsolutePath());
|
||||
|
||||
String modelString = new String(booster.toByteArray("json"));
|
||||
System.out.println(modelString);
|
||||
|
||||
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
|
||||
assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user