[BLOCKING] [jvm-packages] add gpu_hist and enable gpu scheduling (#5171)
* [jvm-packages] add gpu_hist tree method * change updater hist to grow_quantile_histmaker * add gpu scheduling * pass correct parameters to xgboost library * remove debug info * add use.cuda for pom * add CI for gpu_hist for jvm * add gpu unit tests * use gpu node to build jvm * use nvidia-docker * Add CLI interface to create_jni.py using argparse Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -31,8 +31,9 @@ object SparkMLlibPipeline {
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
|
||||
if (args.length != 3) {
|
||||
println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path")
|
||||
if (args.length != 3 && args.length != 4) {
|
||||
println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path " +
|
||||
"[cpu|gpu]")
|
||||
sys.exit(1)
|
||||
}
|
||||
|
||||
@@ -40,6 +41,10 @@ object SparkMLlibPipeline {
|
||||
val nativeModelPath = args(1)
|
||||
val pipelineModelPath = args(2)
|
||||
|
||||
val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
|
||||
("gpu_hist", 1)
|
||||
} else ("auto", 2)
|
||||
|
||||
val spark = SparkSession
|
||||
.builder()
|
||||
.appName("XGBoost4J-Spark Pipeline Example")
|
||||
@@ -76,7 +81,8 @@ object SparkMLlibPipeline {
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> 2
|
||||
"num_workers" -> numWorkers,
|
||||
"tree_method" -> treeMethod
|
||||
)
|
||||
)
|
||||
booster.setFeaturesCol("features")
|
||||
|
||||
@@ -28,9 +28,14 @@ object SparkTraining {
|
||||
def main(args: Array[String]): Unit = {
|
||||
if (args.length < 1) {
|
||||
// scalastyle:off
|
||||
println("Usage: program input_path")
|
||||
println("Usage: program input_path [cpu|gpu]")
|
||||
sys.exit(1)
|
||||
}
|
||||
|
||||
val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
|
||||
("gpu_hist", 1)
|
||||
} else ("auto", 2)
|
||||
|
||||
val spark = SparkSession.builder().getOrCreate()
|
||||
val inputPath = args(0)
|
||||
val schema = new StructType(Array(
|
||||
@@ -68,7 +73,8 @@ object SparkTraining {
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> 2,
|
||||
"num_workers" -> numWorkers,
|
||||
"tree_method" -> treeMethod,
|
||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgbClassifier = new XGBoostClassifier(xgbParam).
|
||||
setFeaturesCol("features").
|
||||
|
||||
Reference in New Issue
Block a user