[jvm-packages] Add the new device parameter. (#9385)

This commit is contained in:
Jiaming Yuan
2023-07-17 18:40:39 +08:00
committed by GitHub
parent 2caceb157d
commit f4fb2be101
15 changed files with 112 additions and 47 deletions

View File

@@ -40,20 +40,20 @@ object SparkMLlibPipeline {
val nativeModelPath = args(1)
val pipelineModelPath = args(2)
-val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
-("gpu_hist", 1)
-} else ("auto", 2)
+val (device, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
+("cuda", 1)
+} else ("cpu", 2)
val spark = SparkSession
.builder()
.appName("XGBoost4J-Spark Pipeline Example")
.getOrCreate()
-run(spark, inputPath, nativeModelPath, pipelineModelPath, treeMethod, numWorkers)
+run(spark, inputPath, nativeModelPath, pipelineModelPath, device, numWorkers)
.show(false)
}
private[spark] def run(spark: SparkSession, inputPath: String, nativeModelPath: String,
-pipelineModelPath: String, treeMethod: String,
+pipelineModelPath: String, device: String,
numWorkers: Int): DataFrame = {
// Load dataset
@@ -82,13 +82,14 @@ object SparkMLlibPipeline {
.setOutputCol("classIndex")
.fit(training)
val booster = new XGBoostClassifier(
-Map("eta" -> 0.1f,
+Map(
+"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> numWorkers,
-"tree_method" -> treeMethod
+"device" -> device
)
)
booster.setFeaturesCol("features")

View File

@@ -31,18 +31,18 @@ object SparkTraining {
sys.exit(1)
}
-val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
-("gpu_hist", 1)
-} else ("auto", 2)
+val (device, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
+("cuda", 1)
+} else ("cpu", 2)
val spark = SparkSession.builder().getOrCreate()
val inputPath = args(0)
-val results: DataFrame = run(spark, inputPath, treeMethod, numWorkers)
+val results: DataFrame = run(spark, inputPath, device, numWorkers)
results.show()
}
private[spark] def run(spark: SparkSession, inputPath: String,
-treeMethod: String, numWorkers: Int): DataFrame = {
+device: String, numWorkers: Int): DataFrame = {
val schema = new StructType(Array(
StructField("sepal length", DoubleType, true),
StructField("sepal width", DoubleType, true),
@@ -80,7 +80,7 @@ private[spark] def run(spark: SparkSession, inputPath: String,
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> numWorkers,
-"tree_method" -> treeMethod,
+"device" -> device,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").

View File

@@ -104,7 +104,7 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
test("Smoke test for SparkMLlibPipeline example") {
SparkMLlibPipeline.run(spark, pathToTestDataset.toString, "target/native-model",
-"target/pipeline-model", "auto", 2)
+"target/pipeline-model", "cpu", 2)
}
test("Smoke test for SparkTraining example") {
@@ -118,6 +118,6 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
.config("spark.task.cpus", 1)
.getOrCreate()
-SparkTraining.run(spark, pathToTestDataset.toString, "auto", 2)
+SparkTraining.run(spark, pathToTestDataset.toString, "cpu", 2)
}
}