[jvm-packages] Add the new device parameter. (#9385)
This commit is contained in:
@@ -40,20 +40,20 @@ object SparkMLlibPipeline {
|
||||
val nativeModelPath = args(1)
|
||||
val pipelineModelPath = args(2)
|
||||
|
||||
val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
|
||||
("gpu_hist", 1)
|
||||
} else ("auto", 2)
|
||||
val (device, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
|
||||
("cuda", 1)
|
||||
} else ("cpu", 2)
|
||||
|
||||
val spark = SparkSession
|
||||
.builder()
|
||||
.appName("XGBoost4J-Spark Pipeline Example")
|
||||
.getOrCreate()
|
||||
|
||||
run(spark, inputPath, nativeModelPath, pipelineModelPath, treeMethod, numWorkers)
|
||||
run(spark, inputPath, nativeModelPath, pipelineModelPath, device, numWorkers)
|
||||
.show(false)
|
||||
}
|
||||
private[spark] def run(spark: SparkSession, inputPath: String, nativeModelPath: String,
|
||||
pipelineModelPath: String, treeMethod: String,
|
||||
pipelineModelPath: String, device: String,
|
||||
numWorkers: Int): DataFrame = {
|
||||
|
||||
// Load dataset
|
||||
@@ -82,13 +82,14 @@ object SparkMLlibPipeline {
|
||||
.setOutputCol("classIndex")
|
||||
.fit(training)
|
||||
val booster = new XGBoostClassifier(
|
||||
Map("eta" -> 0.1f,
|
||||
Map(
|
||||
"eta" -> 0.1f,
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> numWorkers,
|
||||
"tree_method" -> treeMethod
|
||||
"device" -> device
|
||||
)
|
||||
)
|
||||
booster.setFeaturesCol("features")
|
||||
|
||||
@@ -31,18 +31,18 @@ object SparkTraining {
|
||||
sys.exit(1)
|
||||
}
|
||||
|
||||
val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
|
||||
("gpu_hist", 1)
|
||||
} else ("auto", 2)
|
||||
val (device, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
|
||||
("cuda", 1)
|
||||
} else ("cpu", 2)
|
||||
|
||||
val spark = SparkSession.builder().getOrCreate()
|
||||
val inputPath = args(0)
|
||||
val results: DataFrame = run(spark, inputPath, treeMethod, numWorkers)
|
||||
val results: DataFrame = run(spark, inputPath, device, numWorkers)
|
||||
results.show()
|
||||
}
|
||||
|
||||
private[spark] def run(spark: SparkSession, inputPath: String,
|
||||
treeMethod: String, numWorkers: Int): DataFrame = {
|
||||
device: String, numWorkers: Int): DataFrame = {
|
||||
val schema = new StructType(Array(
|
||||
StructField("sepal length", DoubleType, true),
|
||||
StructField("sepal width", DoubleType, true),
|
||||
@@ -80,7 +80,7 @@ private[spark] def run(spark: SparkSession, inputPath: String,
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> numWorkers,
|
||||
"tree_method" -> treeMethod,
|
||||
"device" -> device,
|
||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgbClassifier = new XGBoostClassifier(xgbParam).
|
||||
setFeaturesCol("features").
|
||||
|
||||
@@ -104,7 +104,7 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
||||
|
||||
test("Smoke test for SparkMLlibPipeline example") {
|
||||
SparkMLlibPipeline.run(spark, pathToTestDataset.toString, "target/native-model",
|
||||
"target/pipeline-model", "auto", 2)
|
||||
"target/pipeline-model", "cpu", 2)
|
||||
}
|
||||
|
||||
test("Smoke test for SparkTraining example") {
|
||||
@@ -118,6 +118,6 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
||||
.config("spark.task.cpus", 1)
|
||||
.getOrCreate()
|
||||
|
||||
SparkTraining.run(spark, pathToTestDataset.toString, "auto", 2)
|
||||
SparkTraining.run(spark, pathToTestDataset.toString, "cpu", 2)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user