[BLOCKING] [jvm-packages] add gpu_hist and enable gpu scheduling (#5171)
* [jvm-packages] add gpu_hist tree method * change updater hist to grow_quantile_histmaker * add gpu scheduling * pass correct parameters to xgboost library * remove debug info * add use.cuda for pom * add CI for gpu_hist for jvm * add gpu unit tests * use gpu node to build jvm * use nvidia-docker * Add CLI interface to create_jni.py using argparse Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
6347fa1c2e
commit
8943eb4314
42
Jenkinsfile
vendored
42
Jenkinsfile
vendored
@ -75,6 +75,7 @@ pipeline {
|
|||||||
'build-gpu-cuda10.1': { BuildCUDA(cuda_version: '10.1') },
|
'build-gpu-cuda10.1': { BuildCUDA(cuda_version: '10.1') },
|
||||||
'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2') },
|
'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2') },
|
||||||
'build-gpu-cuda11.0': { BuildCUDA(cuda_version: '11.0') },
|
'build-gpu-cuda11.0': { BuildCUDA(cuda_version: '11.0') },
|
||||||
|
'build-jvm-packages-gpu-cuda10.0': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '10.0') },
|
||||||
'build-jvm-packages': { BuildJVMPackages(spark_version: '3.0.0') },
|
'build-jvm-packages': { BuildJVMPackages(spark_version: '3.0.0') },
|
||||||
'build-jvm-doc': { BuildJVMDoc() }
|
'build-jvm-doc': { BuildJVMDoc() }
|
||||||
])
|
])
|
||||||
@ -94,6 +95,7 @@ pipeline {
|
|||||||
'test-python-mgpu-cuda10.2': { TestPythonGPU(host_cuda_version: '10.2', multi_gpu: true) },
|
'test-python-mgpu-cuda10.2': { TestPythonGPU(host_cuda_version: '10.2', multi_gpu: true) },
|
||||||
'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2') },
|
'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2') },
|
||||||
'test-cpp-gpu-cuda11.0': { TestCppGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') },
|
'test-cpp-gpu-cuda11.0': { TestCppGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') },
|
||||||
|
'test-jvm-jdk8-cuda10.0': { CrossTestJVMwithJDKGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.0') },
|
||||||
'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '3.0.0') },
|
'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '3.0.0') },
|
||||||
'test-jvm-jdk11': { CrossTestJVMwithJDK(jdk_version: '11') },
|
'test-jvm-jdk11': { CrossTestJVMwithJDK(jdk_version: '11') },
|
||||||
'test-jvm-jdk12': { CrossTestJVMwithJDK(jdk_version: '12') },
|
'test-jvm-jdk12': { CrossTestJVMwithJDK(jdk_version: '12') },
|
||||||
@ -282,6 +284,28 @@ def BuildCUDA(args) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def BuildJVMPackagesWithCUDA(args) {
|
||||||
|
node('linux && gpu') {
|
||||||
|
unstash name: 'srcs'
|
||||||
|
echo "Build XGBoost4J-Spark with Spark ${args.spark_version}, CUDA ${args.cuda_version}"
|
||||||
|
def container_type = "jvm_gpu_build"
|
||||||
|
def docker_binary = "nvidia-docker"
|
||||||
|
def docker_args = "--build-arg CUDA_VERSION=${args.cuda_version}"
|
||||||
|
def arch_flag = ""
|
||||||
|
if (env.BRANCH_NAME != 'master' && !(env.BRANCH_NAME.startsWith('release'))) {
|
||||||
|
arch_flag = "-DGPU_COMPUTE_VER=75"
|
||||||
|
}
|
||||||
|
// Use only 4 CPU cores
|
||||||
|
def docker_extra_params = "CI_DOCKER_EXTRA_PARAMS_INIT='--cpuset-cpus 0-3'"
|
||||||
|
sh """
|
||||||
|
${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/build_jvm_packages.sh ${args.spark_version} -Duse.cuda=ON $arch_flag
|
||||||
|
"""
|
||||||
|
echo "Stashing XGBoost4J JAR with CUDA ${args.cuda_version} ..."
|
||||||
|
stash name: 'xgboost4j_jar_gpu', includes: "jvm-packages/xgboost4j/target/*.jar,jvm-packages/xgboost4j-spark/target/*.jar,jvm-packages/xgboost4j-example/target/*.jar"
|
||||||
|
deleteDir()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def BuildJVMPackages(args) {
|
def BuildJVMPackages(args) {
|
||||||
node('linux && cpu') {
|
node('linux && cpu') {
|
||||||
unstash name: 'srcs'
|
unstash name: 'srcs'
|
||||||
@ -386,6 +410,24 @@ def TestCppGPU(args) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def CrossTestJVMwithJDKGPU(args) {
|
||||||
|
def nodeReq = 'linux && mgpu'
|
||||||
|
node(nodeReq) {
|
||||||
|
unstash name: "xgboost4j_jar_gpu"
|
||||||
|
unstash name: 'srcs'
|
||||||
|
if (args.spark_version != null) {
|
||||||
|
echo "Test XGBoost4J on a machine with JDK ${args.jdk_version}, Spark ${args.spark_version}, CUDA ${args.host_cuda_version}"
|
||||||
|
} else {
|
||||||
|
echo "Test XGBoost4J on a machine with JDK ${args.jdk_version}, CUDA ${args.host_cuda_version}"
|
||||||
|
}
|
||||||
|
def container_type = "gpu_jvm"
|
||||||
|
def docker_binary = "nvidia-docker"
|
||||||
|
def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
|
||||||
|
sh "${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_jvm_gpu_cross.sh"
|
||||||
|
deleteDir()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def CrossTestJVMwithJDK(args) {
|
def CrossTestJVMwithJDK(args) {
|
||||||
node('linux && cpu') {
|
node('linux && cpu') {
|
||||||
unstash name: 'xgboost4j_jar'
|
unstash name: 'xgboost4j_jar'
|
||||||
|
|||||||
@ -202,6 +202,14 @@ If you are on Mac OS and using a compiler that supports OpenMP, you need to go t
|
|||||||
|
|
||||||
in order to get the benefit of multi-threading.
|
in order to get the benefit of multi-threading.
|
||||||
|
|
||||||
|
Building with GPU support
|
||||||
|
-------------------------
|
||||||
|
If you want to build XGBoost4J that supports distributed GPU training, run
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
mvn -Duse.cuda=ON install
|
||||||
|
|
||||||
********
|
********
|
||||||
Contents
|
Contents
|
||||||
********
|
********
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import errno
|
import errno
|
||||||
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@ -7,7 +8,6 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
# Monkey-patch the API inconsistency between Python2.X and 3.X.
|
# Monkey-patch the API inconsistency between Python2.X and 3.X.
|
||||||
if sys.platform.startswith("linux"):
|
if sys.platform.startswith("linux"):
|
||||||
sys.platform = "linux"
|
sys.platform = "linux"
|
||||||
@ -20,6 +20,7 @@ CONFIG = {
|
|||||||
"USE_S3": "OFF",
|
"USE_S3": "OFF",
|
||||||
|
|
||||||
"USE_CUDA": "OFF",
|
"USE_CUDA": "OFF",
|
||||||
|
"USE_NCCL": "OFF",
|
||||||
"JVM_BINDINGS": "ON"
|
"JVM_BINDINGS": "ON"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,6 +69,10 @@ def normpath(path):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--use-cuda', type=str, choices=['ON', 'OFF'], default='OFF')
|
||||||
|
cli_args = parser.parse_args()
|
||||||
|
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
# Enable of your compiler supports OpenMP.
|
# Enable of your compiler supports OpenMP.
|
||||||
CONFIG["USE_OPENMP"] = "OFF"
|
CONFIG["USE_OPENMP"] = "OFF"
|
||||||
@ -88,12 +93,21 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
maybe_parallel_build = ""
|
maybe_parallel_build = ""
|
||||||
|
|
||||||
|
if cli_args.use_cuda == 'ON':
|
||||||
|
CONFIG['USE_CUDA'] = 'ON'
|
||||||
|
CONFIG['USE_NCCL'] = 'ON'
|
||||||
|
|
||||||
args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
|
args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
|
||||||
|
|
||||||
# if enviorment set rabit_mock
|
# if enviorment set rabit_mock
|
||||||
if os.getenv("RABIT_MOCK", None) is not None:
|
if os.getenv("RABIT_MOCK", None) is not None:
|
||||||
args.append("-DRABIT_MOCK:BOOL=ON")
|
args.append("-DRABIT_MOCK:BOOL=ON")
|
||||||
|
|
||||||
|
# if enviorment set GPU_ARCH_FLAG
|
||||||
|
gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None)
|
||||||
|
if gpu_arch_flag is not None:
|
||||||
|
args.append("%s" % gpu_arch_flag)
|
||||||
|
|
||||||
run("cmake .. " + " ".join(args) + maybe_generator)
|
run("cmake .. " + " ".join(args) + maybe_generator)
|
||||||
run("cmake --build . --config Release" + maybe_parallel_build)
|
run("cmake --build . --config Release" + maybe_parallel_build)
|
||||||
|
|
||||||
|
|||||||
@ -38,6 +38,7 @@
|
|||||||
<scala.version>2.12.8</scala.version>
|
<scala.version>2.12.8</scala.version>
|
||||||
<scala.binary.version>2.12</scala.binary.version>
|
<scala.binary.version>2.12</scala.binary.version>
|
||||||
<hadoop.version>2.7.3</hadoop.version>
|
<hadoop.version>2.7.3</hadoop.version>
|
||||||
|
<use.cuda>OFF</use.cuda>
|
||||||
</properties>
|
</properties>
|
||||||
<repositories>
|
<repositories>
|
||||||
<repository>
|
<repository>
|
||||||
@ -52,7 +53,65 @@
|
|||||||
<module>xgboost4j-spark</module>
|
<module>xgboost4j-spark</module>
|
||||||
<module>xgboost4j-flink</module>
|
<module>xgboost4j-flink</module>
|
||||||
</modules>
|
</modules>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
<profile>
|
||||||
|
<!-- default active profile excluding gpu related test suites -->
|
||||||
|
<id>default</id>
|
||||||
|
<activation>
|
||||||
|
<activeByDefault>true</activeByDefault>
|
||||||
|
</activation>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.scalatest</groupId>
|
||||||
|
<artifactId>scalatest-maven-plugin</artifactId>
|
||||||
|
<configuration>
|
||||||
|
<tagsToExclude>ml.dmlc.xgboost4j.java.GpuTestSuite</tagsToExclude>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
</profile>
|
||||||
|
|
||||||
|
<!-- gpu profile with both cpu and gpu test suites -->
|
||||||
|
<profile>
|
||||||
|
<id>gpu</id>
|
||||||
|
<activation>
|
||||||
|
<property>
|
||||||
|
<name>use.cuda</name>
|
||||||
|
<value>ON</value>
|
||||||
|
</property>
|
||||||
|
</activation>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.scalatest</groupId>
|
||||||
|
<artifactId>scalatest-maven-plugin</artifactId>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
</profile>
|
||||||
|
|
||||||
|
<!-- gpu-with-gpu-tests profile with only gpu test suites -->
|
||||||
|
<profile>
|
||||||
|
<id>gpu-with-gpu-tests</id>
|
||||||
|
<properties>
|
||||||
|
<use.cuda>ON</use.cuda>
|
||||||
|
</properties>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.scalatest</groupId>
|
||||||
|
<artifactId>scalatest-maven-plugin</artifactId>
|
||||||
|
<configuration>
|
||||||
|
<tagsToInclude>ml.dmlc.xgboost4j.java.GpuTestSuite</tagsToInclude>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>release</id>
|
<id>release</id>
|
||||||
<build>
|
<build>
|
||||||
@ -242,6 +301,25 @@
|
|||||||
<filtering>true</filtering>
|
<filtering>true</filtering>
|
||||||
</resource>
|
</resource>
|
||||||
</resources>
|
</resources>
|
||||||
|
|
||||||
|
<pluginManagement>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.scalatest</groupId>
|
||||||
|
<artifactId>scalatest-maven-plugin</artifactId>
|
||||||
|
<version>1.0</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>test</id>
|
||||||
|
<goals>
|
||||||
|
<goal>test</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</pluginManagement>
|
||||||
|
|
||||||
<plugins>
|
<plugins>
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.scalastyle</groupId>
|
<groupId>org.scalastyle</groupId>
|
||||||
@ -336,15 +414,6 @@
|
|||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.scalatest</groupId>
|
<groupId>org.scalatest</groupId>
|
||||||
<artifactId>scalatest-maven-plugin</artifactId>
|
<artifactId>scalatest-maven-plugin</artifactId>
|
||||||
<version>1.0</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<id>test</id>
|
|
||||||
<goals>
|
|
||||||
<goal>test</goal>
|
|
||||||
</goals>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
<extensions>
|
<extensions>
|
||||||
|
|||||||
@ -31,8 +31,9 @@ object SparkMLlibPipeline {
|
|||||||
|
|
||||||
def main(args: Array[String]): Unit = {
|
def main(args: Array[String]): Unit = {
|
||||||
|
|
||||||
if (args.length != 3) {
|
if (args.length != 3 && args.length != 4) {
|
||||||
println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path")
|
println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path " +
|
||||||
|
"[cpu|gpu]")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,6 +41,10 @@ object SparkMLlibPipeline {
|
|||||||
val nativeModelPath = args(1)
|
val nativeModelPath = args(1)
|
||||||
val pipelineModelPath = args(2)
|
val pipelineModelPath = args(2)
|
||||||
|
|
||||||
|
val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
|
||||||
|
("gpu_hist", 1)
|
||||||
|
} else ("auto", 2)
|
||||||
|
|
||||||
val spark = SparkSession
|
val spark = SparkSession
|
||||||
.builder()
|
.builder()
|
||||||
.appName("XGBoost4J-Spark Pipeline Example")
|
.appName("XGBoost4J-Spark Pipeline Example")
|
||||||
@ -76,7 +81,8 @@ object SparkMLlibPipeline {
|
|||||||
"objective" -> "multi:softprob",
|
"objective" -> "multi:softprob",
|
||||||
"num_class" -> 3,
|
"num_class" -> 3,
|
||||||
"num_round" -> 100,
|
"num_round" -> 100,
|
||||||
"num_workers" -> 2
|
"num_workers" -> numWorkers,
|
||||||
|
"tree_method" -> treeMethod
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
booster.setFeaturesCol("features")
|
booster.setFeaturesCol("features")
|
||||||
|
|||||||
@ -28,9 +28,14 @@ object SparkTraining {
|
|||||||
def main(args: Array[String]): Unit = {
|
def main(args: Array[String]): Unit = {
|
||||||
if (args.length < 1) {
|
if (args.length < 1) {
|
||||||
// scalastyle:off
|
// scalastyle:off
|
||||||
println("Usage: program input_path")
|
println("Usage: program input_path [cpu|gpu]")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
|
||||||
|
("gpu_hist", 1)
|
||||||
|
} else ("auto", 2)
|
||||||
|
|
||||||
val spark = SparkSession.builder().getOrCreate()
|
val spark = SparkSession.builder().getOrCreate()
|
||||||
val inputPath = args(0)
|
val inputPath = args(0)
|
||||||
val schema = new StructType(Array(
|
val schema = new StructType(Array(
|
||||||
@ -68,7 +73,8 @@ object SparkTraining {
|
|||||||
"objective" -> "multi:softprob",
|
"objective" -> "multi:softprob",
|
||||||
"num_class" -> 3,
|
"num_class" -> 3,
|
||||||
"num_round" -> 100,
|
"num_round" -> 100,
|
||||||
"num_workers" -> 2,
|
"num_workers" -> numWorkers,
|
||||||
|
"tree_method" -> treeMethod,
|
||||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
||||||
val xgbClassifier = new XGBoostClassifier(xgbParam).
|
val xgbClassifier = new XGBoostClassifier(xgbParam).
|
||||||
setFeaturesCol("features").
|
setFeaturesCol("features").
|
||||||
|
|||||||
@ -22,7 +22,6 @@ import java.nio.file.Files
|
|||||||
import scala.collection.{AbstractIterator, mutable}
|
import scala.collection.{AbstractIterator, mutable}
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||||
@ -32,7 +31,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
|||||||
import org.apache.commons.io.FileUtils
|
import org.apache.commons.io.FileUtils
|
||||||
import org.apache.commons.logging.LogFactory
|
import org.apache.commons.logging.LogFactory
|
||||||
import org.apache.hadoop.fs.FileSystem
|
import org.apache.hadoop.fs.FileSystem
|
||||||
|
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
|
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
|
||||||
import org.apache.spark.sql.SparkSession
|
import org.apache.spark.sql.SparkSession
|
||||||
@ -76,7 +74,9 @@ private[this] case class XGBoostExecutionParams(
|
|||||||
checkpointParam: Option[ExternalCheckpointParams],
|
checkpointParam: Option[ExternalCheckpointParams],
|
||||||
xgbInputParams: XGBoostExecutionInputParams,
|
xgbInputParams: XGBoostExecutionInputParams,
|
||||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||||
cacheTrainingSet: Boolean) {
|
cacheTrainingSet: Boolean,
|
||||||
|
treeMethod: Option[String],
|
||||||
|
isLocal: Boolean) {
|
||||||
|
|
||||||
private var rawParamMap: Map[String, Any] = _
|
private var rawParamMap: Map[String, Any] = _
|
||||||
|
|
||||||
@ -93,6 +93,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
|
|
||||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||||
|
|
||||||
|
private val isLocal = sc.isLocal
|
||||||
|
|
||||||
private val overridedParams = overrideParams(rawParams, sc)
|
private val overridedParams = overrideParams(rawParams, sc)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -168,11 +170,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
.getOrElse("allow_non_zero_for_missing", false)
|
.getOrElse("allow_non_zero_for_missing", false)
|
||||||
.asInstanceOf[Boolean]
|
.asInstanceOf[Boolean]
|
||||||
validateSparkSslConf
|
validateSparkSslConf
|
||||||
|
var treeMethod: Option[String] = None
|
||||||
if (overridedParams.contains("tree_method")) {
|
if (overridedParams.contains("tree_method")) {
|
||||||
require(overridedParams("tree_method") == "hist" ||
|
require(overridedParams("tree_method") == "hist" ||
|
||||||
overridedParams("tree_method") == "approx" ||
|
overridedParams("tree_method") == "approx" ||
|
||||||
overridedParams("tree_method") == "auto", "xgboost4j-spark only supports tree_method as" +
|
overridedParams("tree_method") == "auto" ||
|
||||||
" 'hist', 'approx' and 'auto'")
|
overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" +
|
||||||
|
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
||||||
|
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
||||||
}
|
}
|
||||||
if (overridedParams.contains("train_test_ratio")) {
|
if (overridedParams.contains("train_test_ratio")) {
|
||||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||||
@ -221,7 +226,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
checkpointParam,
|
checkpointParam,
|
||||||
inputParams,
|
inputParams,
|
||||||
xgbExecEarlyStoppingParams,
|
xgbExecEarlyStoppingParams,
|
||||||
cacheTrainingSet)
|
cacheTrainingSet,
|
||||||
|
treeMethod,
|
||||||
|
isLocal)
|
||||||
xgbExecParam.setRawParamMap(overridedParams)
|
xgbExecParam.setRawParamMap(overridedParams)
|
||||||
xgbExecParam
|
xgbExecParam
|
||||||
}
|
}
|
||||||
@ -335,6 +342,26 @@ object XGBoost extends Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private def getGPUAddrFromResources: Int = {
|
||||||
|
val tc = TaskContext.get()
|
||||||
|
if (tc == null) {
|
||||||
|
throw new RuntimeException("Something wrong for task context")
|
||||||
|
}
|
||||||
|
val resources = tc.resources()
|
||||||
|
if (resources.contains("gpu")) {
|
||||||
|
val addrs = resources("gpu").addresses
|
||||||
|
if (addrs.size > 1) {
|
||||||
|
// TODO should we throw exception ?
|
||||||
|
logger.warn("XGBoost only supports 1 gpu per worker")
|
||||||
|
}
|
||||||
|
// take the first one
|
||||||
|
addrs.head.toInt
|
||||||
|
} else {
|
||||||
|
throw new RuntimeException("gpu is not allocated by spark, " +
|
||||||
|
"please check if gpu scheduling is enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private def buildDistributedBooster(
|
private def buildDistributedBooster(
|
||||||
watches: Watches,
|
watches: Watches,
|
||||||
xgbExecutionParam: XGBoostExecutionParams,
|
xgbExecutionParam: XGBoostExecutionParams,
|
||||||
@ -362,13 +389,25 @@ object XGBoost extends Serializable {
|
|||||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
||||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
||||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||||
|
|
||||||
|
var params = xgbExecutionParam.toMap
|
||||||
|
if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) {
|
||||||
|
val gpuId = if (xgbExecutionParam.isLocal) {
|
||||||
|
// For local mode, force gpu id to primary device
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
getGPUAddrFromResources
|
||||||
|
}
|
||||||
|
logger.info("Leveraging gpu device " + gpuId + " to train")
|
||||||
|
params = params + ("gpu_id" -> gpuId)
|
||||||
|
}
|
||||||
val booster = if (makeCheckpoint) {
|
val booster = if (makeCheckpoint) {
|
||||||
SXGBoost.trainAndSaveCheckpoint(
|
SXGBoost.trainAndSaveCheckpoint(
|
||||||
watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
|
watches.toMap("train"), params, numRounds,
|
||||||
watches.toMap, metrics, obj, eval,
|
watches.toMap, metrics, obj, eval,
|
||||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
|
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
|
||||||
} else {
|
} else {
|
||||||
SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
|
SXGBoost.train(watches.toMap("train"), params, numRounds,
|
||||||
watches.toMap, metrics, obj, eval,
|
watches.toMap, metrics, obj, eval,
|
||||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -145,11 +145,12 @@ private[spark] trait BoosterParams extends Params {
|
|||||||
final def getAlpha: Double = $(alpha)
|
final def getAlpha: Double = $(alpha)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The tree construction algorithm used in XGBoost. options: {'auto', 'exact', 'approx'}
|
* The tree construction algorithm used in XGBoost. options:
|
||||||
* [default='auto']
|
* {'auto', 'exact', 'approx','gpu_hist'} [default='auto']
|
||||||
*/
|
*/
|
||||||
final val treeMethod = new Param[String](this, "treeMethod",
|
final val treeMethod = new Param[String](this, "treeMethod",
|
||||||
"The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}",
|
"The tree construction algorithm used in XGBoost, options: " +
|
||||||
|
"{'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
|
||||||
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
||||||
|
|
||||||
final def getTreeMethod: String = $(treeMethod)
|
final def getTreeMethod: String = $(treeMethod)
|
||||||
@ -292,7 +293,7 @@ private[spark] object BoosterParams {
|
|||||||
|
|
||||||
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
|
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
|
||||||
|
|
||||||
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist")
|
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")
|
||||||
|
|
||||||
val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
|
val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
|
||||||
|
|
||||||
|
|||||||
@ -261,10 +261,10 @@ private[spark] trait ParamMapFuncs extends Params {
|
|||||||
for ((paramName, paramValue) <- xgboostParams) {
|
for ((paramName, paramValue) <- xgboostParams) {
|
||||||
if ((paramName == "booster" && paramValue != "gbtree") ||
|
if ((paramName == "booster" && paramValue != "gbtree") ||
|
||||||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
||||||
paramValue != "hist")) {
|
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
|
||||||
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
|
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
|
||||||
s" XGBoost-Spark only supports gbtree as booster type" +
|
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker,prune or" +
|
||||||
" and grow_histmaker,prune or hist as the updater type")
|
s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
|
||||||
}
|
}
|
||||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||||
params.find(_.name == name).foreach {
|
params.find(_.name == name).foreach {
|
||||||
|
|||||||
@ -16,82 +16,16 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.GpuTestSuite
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||||
import org.apache.spark.ml.linalg._
|
import org.apache.spark.ml.linalg._
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
import org.apache.spark.Partitioner
|
import org.apache.spark.Partitioner
|
||||||
|
|
||||||
class XGBoostClassifierSuite extends FunSuite with PerTest {
|
abstract class XGBoostClassifierSuiteBase extends FunSuite with PerTest {
|
||||||
|
|
||||||
test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
|
protected val treeMethod: String = "auto"
|
||||||
val trainingDM = new DMatrix(Classification.train.iterator)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val trainingDF = buildDataFrame(Classification.train)
|
|
||||||
val testDF = buildDataFrame(Classification.test)
|
|
||||||
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("XGBoostClassifier should make correct predictions after upstream random sort") {
|
|
||||||
val trainingDM = new DMatrix(Classification.train.iterator)
|
|
||||||
val testDM = new DMatrix(Classification.test.iterator)
|
|
||||||
val trainingDF = buildDataFrameWithRandSort(Classification.train)
|
|
||||||
val testDF = buildDataFrameWithRandSort(Classification.test)
|
|
||||||
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def checkResultsWithXGBoost4j(
|
|
||||||
trainingDM: DMatrix,
|
|
||||||
testDM: DMatrix,
|
|
||||||
trainingDF: DataFrame,
|
|
||||||
testDF: DataFrame,
|
|
||||||
round: Int = 5): Unit = {
|
|
||||||
val paramMap = Map(
|
|
||||||
"eta" -> "1",
|
|
||||||
"max_depth" -> "6",
|
|
||||||
"silent" -> "1",
|
|
||||||
"objective" -> "binary:logistic")
|
|
||||||
|
|
||||||
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
|
|
||||||
val prediction1 = model1.predict(testDM)
|
|
||||||
|
|
||||||
val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
|
|
||||||
"num_workers" -> numWorkers)).fit(trainingDF)
|
|
||||||
|
|
||||||
val prediction2 = model2.transform(testDF).
|
|
||||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
|
|
||||||
|
|
||||||
assert(testDF.count() === prediction2.size)
|
|
||||||
// the vector length in probability column is 2 since we have to fit to the evaluator in Spark
|
|
||||||
for (i <- prediction1.indices) {
|
|
||||||
assert(prediction1(i).length === prediction2(i).values.length - 1)
|
|
||||||
for (j <- prediction1(i).indices) {
|
|
||||||
assert(prediction1(i)(j) === prediction2(i)(j + 1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val prediction3 = model1.predict(testDM, outPutMargin = true)
|
|
||||||
val prediction4 = model2.transform(testDF).
|
|
||||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
|
|
||||||
|
|
||||||
assert(testDF.count() === prediction4.size)
|
|
||||||
// the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
|
|
||||||
for (i <- prediction3.indices) {
|
|
||||||
assert(prediction3(i).length === prediction4(i).values.length - 1)
|
|
||||||
for (j <- prediction3(i).indices) {
|
|
||||||
assert(prediction3(i)(j) === prediction4(i)(j + 1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check the equality of single instance prediction
|
|
||||||
val firstOfDM = testDM.slice(Array(0))
|
|
||||||
val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
|
|
||||||
.head()
|
|
||||||
.getAs[Vector]("features")
|
|
||||||
val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
|
|
||||||
val prediction6 = model2.predict(firstOfDF)
|
|
||||||
assert(prediction5 === prediction6)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("Set params in XGBoost and MLlib way should produce same model") {
|
test("Set params in XGBoost and MLlib way should produce same model") {
|
||||||
val trainingDF = buildDataFrame(Classification.train)
|
val trainingDF = buildDataFrame(Classification.train)
|
||||||
@ -104,6 +38,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
"silent" -> "1",
|
"silent" -> "1",
|
||||||
"objective" -> "binary:logistic",
|
"objective" -> "binary:logistic",
|
||||||
"num_round" -> round,
|
"num_round" -> round,
|
||||||
|
"tree_method" -> treeMethod,
|
||||||
"num_workers" -> numWorkers)
|
"num_workers" -> numWorkers)
|
||||||
|
|
||||||
// Set params in XGBoost way
|
// Set params in XGBoost way
|
||||||
@ -128,7 +63,8 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
|
|
||||||
test("test schema of XGBoostClassificationModel") {
|
test("test schema of XGBoostClassificationModel") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
|
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||||
|
"tree_method" -> treeMethod)
|
||||||
val trainingDF = buildDataFrame(Classification.train)
|
val trainingDF = buildDataFrame(Classification.train)
|
||||||
val testDF = buildDataFrame(Classification.test)
|
val testDF = buildDataFrame(Classification.test)
|
||||||
|
|
||||||
@ -160,7 +96,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
test("multi class classification") {
|
test("multi class classification") {
|
||||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
|
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
|
||||||
"num_workers" -> numWorkers)
|
"num_workers" -> numWorkers, "tree_method" -> treeMethod)
|
||||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||||
val xgb = new XGBoostClassifier(paramMap)
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
val model = xgb.fit(trainingDF)
|
val model = xgb.fit(trainingDF)
|
||||||
@ -175,7 +111,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
val test = buildDataFrame(Classification.test)
|
val test = buildDataFrame(Classification.test)
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
|
"objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
|
||||||
|
|
||||||
val xgb = new XGBoostClassifier(paramMap)
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
val model1 = xgb.fit(training1)
|
val model1 = xgb.fit(training1)
|
||||||
@ -194,7 +130,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
test("test predictionLeaf") {
|
test("test predictionLeaf") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
val test = buildDataFrame(Classification.test)
|
val test = buildDataFrame(Classification.test)
|
||||||
val groundTruth = test.count()
|
val groundTruth = test.count()
|
||||||
@ -209,7 +145,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
test("test predictionLeaf with empty column name") {
|
test("test predictionLeaf with empty column name") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
val test = buildDataFrame(Classification.test)
|
val test = buildDataFrame(Classification.test)
|
||||||
val xgb = new XGBoostClassifier(paramMap)
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
@ -222,7 +158,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
test("test predictionContrib") {
|
test("test predictionContrib") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
val test = buildDataFrame(Classification.test)
|
val test = buildDataFrame(Classification.test)
|
||||||
val groundTruth = test.count()
|
val groundTruth = test.count()
|
||||||
@ -237,7 +173,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
test("test predictionContrib with empty column name") {
|
test("test predictionContrib with empty column name") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
val test = buildDataFrame(Classification.test)
|
val test = buildDataFrame(Classification.test)
|
||||||
val xgb = new XGBoostClassifier(paramMap)
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
@ -250,7 +186,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
test("test predictionLeaf and predictionContrib") {
|
test("test predictionLeaf and predictionContrib") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
val test = buildDataFrame(Classification.test)
|
val test = buildDataFrame(Classification.test)
|
||||||
val groundTruth = test.count()
|
val groundTruth = test.count()
|
||||||
@ -264,6 +200,80 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
assert(resultDF.columns.contains("predictContrib"))
|
assert(resultDF.columns.contains("predictContrib"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
class XGBoostCpuClassifierSuite extends XGBoostClassifierSuiteBase {
|
||||||
|
test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
|
||||||
|
val trainingDM = new DMatrix(Classification.train.iterator)
|
||||||
|
val testDM = new DMatrix(Classification.test.iterator)
|
||||||
|
val trainingDF = buildDataFrame(Classification.train)
|
||||||
|
val testDF = buildDataFrame(Classification.test)
|
||||||
|
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("XGBoostClassifier should make correct predictions after upstream random sort") {
|
||||||
|
val trainingDM = new DMatrix(Classification.train.iterator)
|
||||||
|
val testDM = new DMatrix(Classification.test.iterator)
|
||||||
|
val trainingDF = buildDataFrameWithRandSort(Classification.train)
|
||||||
|
val testDF = buildDataFrameWithRandSort(Classification.test)
|
||||||
|
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def checkResultsWithXGBoost4j(
|
||||||
|
trainingDM: DMatrix,
|
||||||
|
testDM: DMatrix,
|
||||||
|
trainingDF: DataFrame,
|
||||||
|
testDF: DataFrame,
|
||||||
|
round: Int = 5): Unit = {
|
||||||
|
val paramMap = Map(
|
||||||
|
"eta" -> "1",
|
||||||
|
"max_depth" -> "6",
|
||||||
|
"silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic",
|
||||||
|
"tree_method" -> treeMethod,
|
||||||
|
"max_bin" -> 16)
|
||||||
|
|
||||||
|
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
|
||||||
|
val prediction1 = model1.predict(testDM)
|
||||||
|
|
||||||
|
val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
|
||||||
|
"num_workers" -> numWorkers)).fit(trainingDF)
|
||||||
|
|
||||||
|
val prediction2 = model2.transform(testDF).
|
||||||
|
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
|
||||||
|
|
||||||
|
assert(testDF.count() === prediction2.size)
|
||||||
|
// the vector length in probability column is 2 since we have to fit to the evaluator in Spark
|
||||||
|
for (i <- prediction1.indices) {
|
||||||
|
assert(prediction1(i).length === prediction2(i).values.length - 1)
|
||||||
|
for (j <- prediction1(i).indices) {
|
||||||
|
assert(prediction1(i)(j) === prediction2(i)(j + 1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val prediction3 = model1.predict(testDM, outPutMargin = true)
|
||||||
|
val prediction4 = model2.transform(testDF).
|
||||||
|
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
|
||||||
|
|
||||||
|
assert(testDF.count() === prediction4.size)
|
||||||
|
// the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
|
||||||
|
for (i <- prediction3.indices) {
|
||||||
|
assert(prediction3(i).length === prediction4(i).values.length - 1)
|
||||||
|
for (j <- prediction3(i).indices) {
|
||||||
|
assert(prediction3(i)(j) === prediction4(i)(j + 1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check the equality of single instance prediction
|
||||||
|
val firstOfDM = testDM.slice(Array(0))
|
||||||
|
val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
|
||||||
|
.head()
|
||||||
|
.getAs[Vector]("features")
|
||||||
|
val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
|
||||||
|
val prediction6 = model2.predict(firstOfDF)
|
||||||
|
assert(prediction5 === prediction6)
|
||||||
|
}
|
||||||
|
|
||||||
test("infrequent features") {
|
test("infrequent features") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic",
|
"objective" -> "binary:logistic",
|
||||||
@ -305,5 +315,10 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
val xgb = new XGBoostClassifier(paramMap)
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
xgb.fit(repartitioned)
|
xgb.fit(repartitioned)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@GpuTestSuite
|
||||||
|
class XGBoostGpuClassifierSuite extends XGBoostClassifierSuiteBase {
|
||||||
|
override protected val treeMethod: String = "gpu_hist"
|
||||||
|
override protected val numWorkers: Int = 1
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.GpuTestSuite
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||||
import org.apache.spark.ml.linalg.Vector
|
import org.apache.spark.ml.linalg.Vector
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
@ -23,7 +24,8 @@ import org.apache.spark.sql.{DataFrame, Row}
|
|||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
class XGBoostRegressorSuite extends FunSuite with PerTest {
|
abstract class XGBoostRegressorSuiteBase extends FunSuite with PerTest {
|
||||||
|
protected val treeMethod: String = "auto"
|
||||||
|
|
||||||
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
|
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
|
||||||
val trainingDM = new DMatrix(Regression.train.iterator)
|
val trainingDM = new DMatrix(Regression.train.iterator)
|
||||||
@ -51,7 +53,9 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
"eta" -> "1",
|
"eta" -> "1",
|
||||||
"max_depth" -> "6",
|
"max_depth" -> "6",
|
||||||
"silent" -> "1",
|
"silent" -> "1",
|
||||||
"objective" -> "reg:squarederror")
|
"objective" -> "reg:squarederror",
|
||||||
|
"max_bin" -> 16,
|
||||||
|
"tree_method" -> treeMethod)
|
||||||
|
|
||||||
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
|
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
|
||||||
val prediction1 = model1.predict(testDM)
|
val prediction1 = model1.predict(testDM)
|
||||||
@ -88,6 +92,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
"silent" -> "1",
|
"silent" -> "1",
|
||||||
"objective" -> "reg:squarederror",
|
"objective" -> "reg:squarederror",
|
||||||
"num_round" -> round,
|
"num_round" -> round,
|
||||||
|
"tree_method" -> treeMethod,
|
||||||
"num_workers" -> numWorkers)
|
"num_workers" -> numWorkers)
|
||||||
|
|
||||||
// Set params in XGBoost way
|
// Set params in XGBoost way
|
||||||
@ -99,6 +104,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
.setSilent(1)
|
.setSilent(1)
|
||||||
.setObjective("reg:squarederror")
|
.setObjective("reg:squarederror")
|
||||||
.setNumRound(round)
|
.setNumRound(round)
|
||||||
|
.setTreeMethod(treeMethod)
|
||||||
.setNumWorkers(numWorkers)
|
.setNumWorkers(numWorkers)
|
||||||
.fit(trainingDF)
|
.fit(trainingDF)
|
||||||
|
|
||||||
@ -113,7 +119,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
test("ranking: use group data") {
|
test("ranking: use group data") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5,
|
"objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5,
|
||||||
"group_col" -> "group")
|
"group_col" -> "group", "tree_method" -> treeMethod)
|
||||||
|
|
||||||
val trainingDF = buildDataFrameWithGroup(Ranking.train)
|
val trainingDF = buildDataFrameWithGroup(Ranking.train)
|
||||||
val testDF = buildDataFrame(Ranking.test)
|
val testDF = buildDataFrame(Ranking.test)
|
||||||
@ -125,7 +131,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
|
|
||||||
test("use weight") {
|
test("use weight") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||||
|
"tree_method" -> treeMethod)
|
||||||
|
|
||||||
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f})
|
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f})
|
||||||
val trainingDF = buildDataFrame(Regression.train)
|
val trainingDF = buildDataFrame(Regression.train)
|
||||||
@ -140,7 +147,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
|
|
||||||
test("test predictionLeaf") {
|
test("test predictionLeaf") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||||
|
"tree_method" -> treeMethod)
|
||||||
val training = buildDataFrame(Regression.train)
|
val training = buildDataFrame(Regression.train)
|
||||||
val testDF = buildDataFrame(Regression.test)
|
val testDF = buildDataFrame(Regression.test)
|
||||||
val groundTruth = testDF.count()
|
val groundTruth = testDF.count()
|
||||||
@ -154,7 +162,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
|
|
||||||
test("test predictionLeaf with empty column name") {
|
test("test predictionLeaf with empty column name") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||||
|
"tree_method" -> treeMethod)
|
||||||
val training = buildDataFrame(Regression.train)
|
val training = buildDataFrame(Regression.train)
|
||||||
val testDF = buildDataFrame(Regression.test)
|
val testDF = buildDataFrame(Regression.test)
|
||||||
val xgb = new XGBoostRegressor(paramMap)
|
val xgb = new XGBoostRegressor(paramMap)
|
||||||
@ -166,7 +175,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
|
|
||||||
test("test predictionContrib") {
|
test("test predictionContrib") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||||
|
"tree_method" -> treeMethod)
|
||||||
val training = buildDataFrame(Regression.train)
|
val training = buildDataFrame(Regression.train)
|
||||||
val testDF = buildDataFrame(Regression.test)
|
val testDF = buildDataFrame(Regression.test)
|
||||||
val groundTruth = testDF.count()
|
val groundTruth = testDF.count()
|
||||||
@ -180,7 +190,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
|
|
||||||
test("test predictionContrib with empty column name") {
|
test("test predictionContrib with empty column name") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||||
|
"tree_method" -> treeMethod)
|
||||||
val training = buildDataFrame(Regression.train)
|
val training = buildDataFrame(Regression.train)
|
||||||
val testDF = buildDataFrame(Regression.test)
|
val testDF = buildDataFrame(Regression.test)
|
||||||
val xgb = new XGBoostRegressor(paramMap)
|
val xgb = new XGBoostRegressor(paramMap)
|
||||||
@ -192,7 +203,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
|
|
||||||
test("test predictionLeaf and predictionContrib") {
|
test("test predictionLeaf and predictionContrib") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||||
|
"tree_method" -> treeMethod)
|
||||||
val training = buildDataFrame(Regression.train)
|
val training = buildDataFrame(Regression.train)
|
||||||
val testDF = buildDataFrame(Regression.test)
|
val testDF = buildDataFrame(Regression.test)
|
||||||
val groundTruth = testDF.count()
|
val groundTruth = testDF.count()
|
||||||
@ -206,3 +218,13 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
assert(resultDF.columns.contains("predictContrib"))
|
assert(resultDF.columns.contains("predictContrib"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class XGBoostCpuRegressorSuite extends XGBoostRegressorSuiteBase {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@GpuTestSuite
|
||||||
|
class XGBoostGpuRegressorSuite extends XGBoostRegressorSuiteBase {
|
||||||
|
override protected val treeMethod: String = "gpu_hist"
|
||||||
|
override protected val numWorkers: Int = 1
|
||||||
|
}
|
||||||
|
|||||||
@ -43,6 +43,12 @@
|
|||||||
<version>2.5.23</version>
|
<version>2.5.23</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.scalatest</groupId>
|
||||||
|
<artifactId>scalatest_${scala.binary.version}</artifactId>
|
||||||
|
<version>3.0.5</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<build>
|
<build>
|
||||||
@ -78,6 +84,8 @@
|
|||||||
<executable>python</executable>
|
<executable>python</executable>
|
||||||
<arguments>
|
<arguments>
|
||||||
<argument>create_jni.py</argument>
|
<argument>create_jni.py</argument>
|
||||||
|
<argument>--use-cuda</argument>
|
||||||
|
<argument>${use.cuda}</argument>
|
||||||
</arguments>
|
</arguments>
|
||||||
<workingDirectory>${user.dir}</workingDirectory>
|
<workingDirectory>${user.dir}</workingDirectory>
|
||||||
</configuration>
|
</configuration>
|
||||||
|
|||||||
@ -0,0 +1,28 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2020 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
import org.scalatest.TagAnnotation;
|
||||||
|
|
||||||
|
@TagAnnotation
|
||||||
|
@Retention(RetentionPolicy.RUNTIME)
|
||||||
|
@Target({ElementType.METHOD, ElementType.TYPE})
|
||||||
|
public @interface GpuTestSuite {}
|
||||||
@ -46,6 +46,7 @@ object XGBoost {
|
|||||||
} else {
|
} else {
|
||||||
prevBooster.booster
|
prevBooster.booster
|
||||||
}
|
}
|
||||||
|
|
||||||
val xgboostInJava = checkpointParams.
|
val xgboostInJava = checkpointParams.
|
||||||
map(cp => {
|
map(cp => {
|
||||||
JXGBoost.trainAndSaveCheckpoint(
|
JXGBoost.trainAndSaveCheckpoint(
|
||||||
|
|||||||
51
tests/ci_build/Dockerfile.gpu_jvm
Normal file
51
tests/ci_build/Dockerfile.gpu_jvm
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
ARG CUDA_VERSION
|
||||||
|
FROM nvidia/cuda:$CUDA_VERSION-runtime-ubuntu16.04
|
||||||
|
ARG JDK_VERSION=8
|
||||||
|
ARG SPARK_VERSION=3.0.0
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
ENV DEBIAN_FRONTEND noninteractive
|
||||||
|
|
||||||
|
# Install all basic requirements
|
||||||
|
RUN \
|
||||||
|
apt-get update && \
|
||||||
|
apt-get install -y software-properties-common && \
|
||||||
|
add-apt-repository ppa:openjdk-r/ppa && \
|
||||||
|
apt-get update && \
|
||||||
|
apt-get install -y tar unzip wget openjdk-$JDK_VERSION-jdk libgomp1 && \
|
||||||
|
# Python
|
||||||
|
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
||||||
|
bash Miniconda3.sh -b -p /opt/python && \
|
||||||
|
/opt/python/bin/pip install awscli && \
|
||||||
|
# Maven
|
||||||
|
wget https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \
|
||||||
|
tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \
|
||||||
|
ln -s /opt/apache-maven-3.6.1/ /opt/maven && \
|
||||||
|
# Spark
|
||||||
|
wget https://archive.apache.org/dist/spark/spark-$SPARK_VERSION/spark-$SPARK_VERSION-bin-hadoop2.7.tgz && \
|
||||||
|
tar xvf spark-$SPARK_VERSION-bin-hadoop2.7.tgz -C /opt && \
|
||||||
|
ln -s /opt/spark-$SPARK_VERSION-bin-hadoop2.7 /opt/spark
|
||||||
|
|
||||||
|
ENV PATH=/opt/python/bin:/opt/spark/bin:/opt/maven/bin:$PATH
|
||||||
|
|
||||||
|
# Install Python packages
|
||||||
|
RUN \
|
||||||
|
pip install numpy scipy pandas scikit-learn
|
||||||
|
|
||||||
|
ENV GOSU_VERSION 1.10
|
||||||
|
|
||||||
|
# Install lightweight sudo (not bound to TTY)
|
||||||
|
RUN set -ex; \
|
||||||
|
wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
|
||||||
|
chmod +x /usr/local/bin/gosu && \
|
||||||
|
gosu nobody true
|
||||||
|
|
||||||
|
# Set default JDK version
|
||||||
|
RUN update-java-alternatives -v -s java-1.$JDK_VERSION.0-openjdk-amd64
|
||||||
|
|
||||||
|
# Default entry-point to use if running locally
|
||||||
|
# It will preserve attributes of created files
|
||||||
|
COPY entrypoint.sh /scripts/
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
ENTRYPOINT ["/scripts/entrypoint.sh"]
|
||||||
63
tests/ci_build/Dockerfile.jvm_gpu_build
Normal file
63
tests/ci_build/Dockerfile.jvm_gpu_build
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
ARG CUDA_VERSION
|
||||||
|
FROM nvidia/cuda:$CUDA_VERSION-devel-centos6
|
||||||
|
ARG CUDA_VERSION
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
ENV DEBIAN_FRONTEND noninteractive
|
||||||
|
ENV DEVTOOLSET_URL_ROOT http://vault.centos.org/6.9/sclo/x86_64/rh/devtoolset-4/
|
||||||
|
|
||||||
|
# Install all basic requirements
|
||||||
|
RUN \
|
||||||
|
yum -y update && \
|
||||||
|
yum install -y tar unzip wget xz git centos-release-scl yum-utils java-1.8.0-openjdk-devel && \
|
||||||
|
yum-config-manager --enable centos-sclo-rh-testing && \
|
||||||
|
yum -y update && \
|
||||||
|
yum install -y $DEVTOOLSET_URL_ROOT/devtoolset-4-gcc-5.3.1-6.1.el6.x86_64.rpm \
|
||||||
|
$DEVTOOLSET_URL_ROOT/devtoolset-4-gcc-c++-5.3.1-6.1.el6.x86_64.rpm \
|
||||||
|
$DEVTOOLSET_URL_ROOT/devtoolset-4-binutils-2.25.1-8.el6.x86_64.rpm \
|
||||||
|
$DEVTOOLSET_URL_ROOT/devtoolset-4-runtime-4.1-3.sc1.el6.x86_64.rpm \
|
||||||
|
$DEVTOOLSET_URL_ROOT/devtoolset-4-libstdc++-devel-5.3.1-6.1.el6.x86_64.rpm && \
|
||||||
|
# Python
|
||||||
|
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
||||||
|
bash Miniconda3.sh -b -p /opt/python && \
|
||||||
|
# CMake
|
||||||
|
wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
|
||||||
|
bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr && \
|
||||||
|
# Maven
|
||||||
|
wget https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \
|
||||||
|
tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \
|
||||||
|
ln -s /opt/apache-maven-3.6.1/ /opt/maven
|
||||||
|
|
||||||
|
# NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html)
|
||||||
|
RUN \
|
||||||
|
export CUDA_SHORT=`echo $CUDA_VERSION | egrep -o '[0-9]+\.[0-9]'` && \
|
||||||
|
export NCCL_VERSION=2.4.8-1 && \
|
||||||
|
wget https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
|
||||||
|
rpm -i nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
|
||||||
|
yum -y update && \
|
||||||
|
yum install -y libnccl-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-devel-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-static-${NCCL_VERSION}+cuda${CUDA_SHORT} && \
|
||||||
|
rm -f nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm;
|
||||||
|
|
||||||
|
ENV PATH=/opt/python/bin:/opt/maven/bin:$PATH
|
||||||
|
ENV CC=/opt/rh/devtoolset-4/root/usr/bin/gcc
|
||||||
|
ENV CXX=/opt/rh/devtoolset-4/root/usr/bin/c++
|
||||||
|
ENV CPP=/opt/rh/devtoolset-4/root/usr/bin/cpp
|
||||||
|
|
||||||
|
# Install Python packages
|
||||||
|
RUN \
|
||||||
|
pip install numpy pytest scipy scikit-learn wheel kubernetes urllib3==1.22 awscli
|
||||||
|
|
||||||
|
ENV GOSU_VERSION 1.10
|
||||||
|
|
||||||
|
# Install lightweight sudo (not bound to TTY)
|
||||||
|
RUN set -ex; \
|
||||||
|
wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
|
||||||
|
chmod +x /usr/local/bin/gosu && \
|
||||||
|
gosu nobody true
|
||||||
|
|
||||||
|
# Default entry-point to use if running locally
|
||||||
|
# It will preserve attributes of created files
|
||||||
|
COPY entrypoint.sh /scripts/
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
ENTRYPOINT ["/scripts/entrypoint.sh"]
|
||||||
@ -3,12 +3,15 @@
|
|||||||
set -e
|
set -e
|
||||||
set -x
|
set -x
|
||||||
|
|
||||||
if [ $# -ne 1 ]; then
|
|
||||||
echo "Usage: $0 [spark version]"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
spark_version=$1
|
spark_version=$1
|
||||||
|
use_cuda=$2
|
||||||
|
gpu_arch=$3
|
||||||
|
|
||||||
|
gpu_options=""
|
||||||
|
if [ "x$use_cuda" == "x-Duse.cuda=ON" ]; then
|
||||||
|
# Since building jvm for CPU will do unit tests, choose gpu-with-gpu-tests profile to build
|
||||||
|
gpu_options=" -Pgpu-with-gpu-tests "
|
||||||
|
fi
|
||||||
|
|
||||||
# Initialize local Maven repository
|
# Initialize local Maven repository
|
||||||
./tests/ci_build/initialize_maven.sh
|
./tests/ci_build/initialize_maven.sh
|
||||||
@ -16,7 +19,11 @@ spark_version=$1
|
|||||||
rm -rf build/
|
rm -rf build/
|
||||||
cd jvm-packages
|
cd jvm-packages
|
||||||
export RABIT_MOCK=ON
|
export RABIT_MOCK=ON
|
||||||
mvn --no-transfer-progress package -Dspark.version=${spark_version}
|
|
||||||
|
if [ "x$gpu_arch" != "x" ]; then
|
||||||
|
export GPU_ARCH_FLAG=$gpu_arch
|
||||||
|
fi
|
||||||
|
mvn --no-transfer-progress package -Dspark.version=${spark_version} $gpu_options
|
||||||
|
|
||||||
set +x
|
set +x
|
||||||
set +e
|
set +e
|
||||||
|
|||||||
40
tests/ci_build/test_jvm_gpu_cross.sh
Executable file
40
tests/ci_build/test_jvm_gpu_cross.sh
Executable file
@ -0,0 +1,40 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
set -x
|
||||||
|
|
||||||
|
|
||||||
|
nvidia-smi
|
||||||
|
|
||||||
|
ls /usr/local/
|
||||||
|
|
||||||
|
# Initialize local Maven repository
|
||||||
|
./tests/ci_build/initialize_maven.sh
|
||||||
|
|
||||||
|
# Get version number of XGBoost4J and other auxiliary information
|
||||||
|
cd jvm-packages
|
||||||
|
xgboost4j_version=$(mvn help:evaluate -Dexpression=project.version -q -DforceStdout)
|
||||||
|
scala_binary_version=$(mvn help:evaluate -Dexpression=scala.binary.version -q -DforceStdout)
|
||||||
|
|
||||||
|
python3 xgboost4j-tester/get_iris.py
|
||||||
|
xgb_jars="./xgboost4j/target/xgboost4j_${scala_binary_version}-${xgboost4j_version}.jar,./xgboost4j-spark/target/xgboost4j-spark_${scala_binary_version}-${xgboost4j_version}.jar"
|
||||||
|
example_jar="./xgboost4j-example/target/xgboost4j-example_${scala_binary_version}-${xgboost4j_version}.jar"
|
||||||
|
|
||||||
|
echo "Run SparkTraining locally ... "
|
||||||
|
spark-submit \
|
||||||
|
--master 'local[1]' \
|
||||||
|
--class ml.dmlc.xgboost4j.scala.example.spark.SparkTraining \
|
||||||
|
--jars $xgb_jars \
|
||||||
|
$example_jar \
|
||||||
|
${PWD}/iris.csv gpu \
|
||||||
|
|
||||||
|
echo "Run SparkMLlibPipeline locally ... "
|
||||||
|
spark-submit \
|
||||||
|
--master 'local[1]' \
|
||||||
|
--class ml.dmlc.xgboost4j.scala.example.spark.SparkMLlibPipeline \
|
||||||
|
--jars $xgb_jars \
|
||||||
|
$example_jar \
|
||||||
|
${PWD}/iris.csv ${PWD}/native_model ${PWD}/pipeline_model gpu \
|
||||||
|
|
||||||
|
set +x
|
||||||
|
set +e
|
||||||
Loading…
x
Reference in New Issue
Block a user