From 8943eb43145d2e1787ef7b9e3c25fb921a56e27c Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 27 Jul 2020 12:53:24 +0800 Subject: [PATCH] [BLOCKING] [jvm-packages] add gpu_hist and enable gpu scheduling (#5171) * [jvm-packages] add gpu_hist tree method * change updater hist to grow_quantile_histmaker * add gpu scheduling * pass correct parameters to xgboost library * remove debug info * add use.cuda for pom * add CI for gpu_hist for jvm * add gpu unit tests * use gpu node to build jvm * use nvidia-docker * Add CLI interface to create_jni.py using argparse Co-authored-by: Hyunsu Cho --- Jenkinsfile | 42 +++++ doc/jvm/index.rst | 8 + jvm-packages/create_jni.py | 16 +- jvm-packages/pom.xml | 87 ++++++++- .../example/spark/SparkMLlibPipeline.scala | 12 +- .../scala/example/spark/SparkTraining.scala | 10 +- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 55 +++++- .../scala/spark/params/BoosterParams.scala | 9 +- .../scala/spark/params/GeneralParams.scala | 6 +- .../scala/spark/XGBoostClassifierSuite.scala | 171 ++++++++++-------- .../scala/spark/XGBoostRegressorSuite.scala | 40 +++- jvm-packages/xgboost4j/pom.xml | 8 + .../ml/dmlc/xgboost4j/java/GpuTestSuite.java | 28 +++ .../ml/dmlc/xgboost4j/scala/XGBoost.scala | 1 + tests/ci_build/Dockerfile.gpu_jvm | 51 ++++++ tests/ci_build/Dockerfile.jvm_gpu_build | 63 +++++++ tests/ci_build/build_jvm_packages.sh | 19 +- tests/ci_build/test_jvm_gpu_cross.sh | 40 ++++ 18 files changed, 543 insertions(+), 123 deletions(-) create mode 100644 jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/GpuTestSuite.java create mode 100644 tests/ci_build/Dockerfile.gpu_jvm create mode 100644 tests/ci_build/Dockerfile.jvm_gpu_build create mode 100755 tests/ci_build/test_jvm_gpu_cross.sh diff --git a/Jenkinsfile b/Jenkinsfile index b30cd6ac9..d8309f1fb 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -75,6 +75,7 @@ pipeline { 'build-gpu-cuda10.1': { BuildCUDA(cuda_version: '10.1') }, 'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2') }, 'build-gpu-cuda11.0': { BuildCUDA(cuda_version: '11.0') }, + 'build-jvm-packages-gpu-cuda10.0': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '10.0') }, 'build-jvm-packages': { BuildJVMPackages(spark_version: '3.0.0') }, 'build-jvm-doc': { BuildJVMDoc() } ]) @@ -94,6 +95,7 @@ pipeline { 'test-python-mgpu-cuda10.2': { TestPythonGPU(host_cuda_version: '10.2', multi_gpu: true) }, 'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2') }, 'test-cpp-gpu-cuda11.0': { TestCppGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') }, + 'test-jvm-jdk8-cuda10.0': { CrossTestJVMwithJDKGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.0') }, 'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '3.0.0') }, 'test-jvm-jdk11': { CrossTestJVMwithJDK(jdk_version: '11') }, 'test-jvm-jdk12': { CrossTestJVMwithJDK(jdk_version: '12') }, @@ -282,6 +284,28 @@ def BuildCUDA(args) { } } +def BuildJVMPackagesWithCUDA(args) { + node('linux && gpu') { + unstash name: 'srcs' + echo "Build XGBoost4J-Spark with Spark ${args.spark_version}, CUDA ${args.cuda_version}" + def container_type = "jvm_gpu_build" + def docker_binary = "nvidia-docker" + def docker_args = "--build-arg CUDA_VERSION=${args.cuda_version}" + def arch_flag = "" + if (env.BRANCH_NAME != 'master' && !(env.BRANCH_NAME.startsWith('release'))) { + arch_flag = "-DGPU_COMPUTE_VER=75" + } + // Use only 4 CPU cores + def docker_extra_params = "CI_DOCKER_EXTRA_PARAMS_INIT='--cpuset-cpus 0-3'" + sh """ + 
${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/build_jvm_packages.sh ${args.spark_version} -Duse.cuda=ON $arch_flag
+    """
+    echo "Stashing XGBoost4J JAR with CUDA ${args.cuda_version} ..."
+    stash name: 'xgboost4j_jar_gpu', includes: "jvm-packages/xgboost4j/target/*.jar,jvm-packages/xgboost4j-spark/target/*.jar,jvm-packages/xgboost4j-example/target/*.jar"
+    deleteDir()
+  }
+}
+
 def BuildJVMPackages(args) {
   node('linux && cpu') {
     unstash name: 'srcs'
@@ -386,6 +410,24 @@ def TestCppGPU(args) {
   }
 }
 
+def CrossTestJVMwithJDKGPU(args) {
+  def nodeReq = 'linux && mgpu'
+  node(nodeReq) {
+    unstash name: "xgboost4j_jar_gpu"
+    unstash name: 'srcs'
+    if (args.spark_version != null) {
+      echo "Test XGBoost4J on a machine with JDK ${args.jdk_version}, Spark ${args.spark_version}, CUDA ${args.host_cuda_version}"
+    } else {
+      echo "Test XGBoost4J on a machine with JDK ${args.jdk_version}, CUDA ${args.host_cuda_version}"
+    }
+    def container_type = "gpu_jvm"
+    def docker_binary = "nvidia-docker"
+    def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
+    sh "${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_jvm_gpu_cross.sh"
+    deleteDir()
+  }
+}
+
 def CrossTestJVMwithJDK(args) {
   node('linux && cpu') {
     unstash name: 'xgboost4j_jar'
diff --git a/doc/jvm/index.rst b/doc/jvm/index.rst
index 69e4569c3..a4c9cdd53 100644
--- a/doc/jvm/index.rst
+++ b/doc/jvm/index.rst
@@ -202,6 +202,14 @@ If you are on Mac OS and using a compiler that supports OpenMP, you need to go t
 
 in order to get the benefit of multi-threading.
 
+Building with GPU support
+-------------------------
+If you want to build XGBoost4J with support for distributed GPU training, run
+
+.. code-block:: bash
+
+  mvn -Duse.cuda=ON install
+
 ********
 Contents
 ********
diff --git a/jvm-packages/create_jni.py b/jvm-packages/create_jni.py
index a30886f65..29169c97b 100755
--- a/jvm-packages/create_jni.py
+++ b/jvm-packages/create_jni.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 import errno
+import argparse
 import glob
 import os
 import shutil
@@ -7,7 +8,6 @@ import subprocess
 import sys
 
 from contextlib import contextmanager
-
 # Monkey-patch the API inconsistency between Python2.X and 3.X.
 if sys.platform.startswith("linux"):
     sys.platform = "linux"
@@ -20,6 +20,7 @@ CONFIG = {
     "USE_S3": "OFF",
 
     "USE_CUDA": "OFF",
+    "USE_NCCL": "OFF",
     "JVM_BINDINGS": "ON"
 }
 
@@ -68,6 +69,10 @@ def normpath(path):
 
 
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--use-cuda', type=str, choices=['ON', 'OFF'], default='OFF')
+    cli_args = parser.parse_args()
+
     if sys.platform == "darwin":
         # Enable if your compiler supports OpenMP.
         CONFIG["USE_OPENMP"] = "OFF"
@@ -88,12 +93,21 @@ if __name__ == "__main__":
     else:
         maybe_parallel_build = ""
 
+    if cli_args.use_cuda == 'ON':
+        CONFIG['USE_CUDA'] = 'ON'
+        CONFIG['USE_NCCL'] = 'ON'
+
     args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
 
     # if the environment sets RABIT_MOCK
     if os.getenv("RABIT_MOCK", None) is not None:
         args.append("-DRABIT_MOCK:BOOL=ON")
 
+    # if the environment sets GPU_ARCH_FLAG
+    gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None)
+    if gpu_arch_flag is not None:
+        args.append("%s" % gpu_arch_flag)
+
     run("cmake .. " + " ".join(args) + maybe_generator)
     run("cmake --build . --config Release" + maybe_parallel_build)
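For reference, the exec-maven-plugin change to xgboost4j/pom.xml further down is what normally passes --use-cuda to this script, but it can also be driven by hand. A minimal sketch, assuming a working CUDA toolchain; the GPU_COMPUTE_VER value is simply the one the Jenkinsfile uses for non-release branches:

    # Hand-invoke the JNI build with CUDA + NCCL enabled. GPU_ARCH_FLAG is the
    # optional extra cmake flag that create_jni.py forwards verbatim.
    cd jvm-packages
    GPU_ARCH_FLAG="-DGPU_COMPUTE_VER=75" python create_jni.py --use-cuda ON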
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 5e47ae5c5..04a11bc2b 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -38,6 +38,7 @@
         <scala.version>2.12.8</scala.version>
         <scala.binary.version>2.12</scala.binary.version>
         <hadoop.version>2.7.3</hadoop.version>
+        <use.cuda>OFF</use.cuda>
@@ -52,7 +53,65 @@
         <module>xgboost4j-spark</module>
         <module>xgboost4j-flink</module>
     </modules>
 
     <profiles>
+        <profile>
+            <id>default</id>
+            <activation>
+                <activeByDefault>true</activeByDefault>
+            </activation>
+            <build>
+                <plugins>
+                    <plugin>
+                        <groupId>org.scalatest</groupId>
+                        <artifactId>scalatest-maven-plugin</artifactId>
+                        <configuration>
+                            <tagsToExclude>ml.dmlc.xgboost4j.java.GpuTestSuite</tagsToExclude>
+                        </configuration>
+                    </plugin>
+                </plugins>
+            </build>
+        </profile>
+
+        <profile>
+            <id>gpu</id>
+            <activation>
+                <property>
+                    <name>use.cuda</name>
+                    <value>ON</value>
+                </property>
+            </activation>
+            <build>
+                <plugins>
+                    <plugin>
+                        <groupId>org.scalatest</groupId>
+                        <artifactId>scalatest-maven-plugin</artifactId>
+                        <configuration>
+                            <skipTests>true</skipTests>
+                        </configuration>
+                    </plugin>
+                </plugins>
+            </build>
+        </profile>
+
+        <profile>
+            <id>gpu-with-gpu-tests</id>
+            <properties>
+                <use.cuda>ON</use.cuda>
+            </properties>
+            <build>
+                <plugins>
+                    <plugin>
+                        <groupId>org.scalatest</groupId>
+                        <artifactId>scalatest-maven-plugin</artifactId>
+                        <configuration>
+                            <tagsToInclude>ml.dmlc.xgboost4j.java.GpuTestSuite</tagsToInclude>
+                        </configuration>
+                    </plugin>
+                </plugins>
+            </build>
+        </profile>
+
         <profile>
             <id>release</id>
@@ -242,6 +301,25 @@
                         <skipIfEmpty>true</skipIfEmpty>
                     </configuration>
                 </plugin>
+                <plugin>
+                    <groupId>org.scalatest</groupId>
+                    <artifactId>scalatest-maven-plugin</artifactId>
+                    <version>1.0</version>
+                    <executions>
+                        <execution>
+                            <id>test</id>
+                            <goals>
+                                <goal>test</goal>
+                            </goals>
+                        </execution>
+                    </executions>
+                </plugin>
                 <plugin>
                     <groupId>org.scalastyle</groupId>
@@ -336,15 +414,6 @@
                 <plugin>
                     <groupId>org.scalatest</groupId>
                     <artifactId>scalatest-maven-plugin</artifactId>
-                    <version>1.0</version>
-                    <executions>
-                        <execution>
-                            <id>test</id>
-                            <goals>
-                                <goal>test</goal>
-                            </goals>
-                        </execution>
-                    </executions>
                 </plugin>
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala
index 9e1b02a71..6d676b0ae 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala
@@ -31,8 +31,9 @@ object SparkMLlibPipeline {
 
   def main(args: Array[String]): Unit = {
 
-    if (args.length != 3) {
-      println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path")
+    if (args.length != 3 && args.length != 4) {
+      println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path " +
+        "[cpu|gpu]")
       sys.exit(1)
     }
 
@@ -40,6 +41,10 @@ object SparkMLlibPipeline {
     val nativeModelPath = args(1)
     val pipelineModelPath = args(2)
 
+    val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
+      ("gpu_hist", 1)
+    } else ("auto", 2)
+
     val spark = SparkSession
       .builder()
       .appName("XGBoost4J-Spark Pipeline Example")
@@ -76,7 +81,8 @@ object SparkMLlibPipeline {
         "objective" -> "multi:softprob",
         "num_class" -> 3,
         "num_round" -> 100,
-        "num_workers" -> 2
+        "num_workers" -> numWorkers,
+        "tree_method" -> treeMethod
       )
     )
     booster.setFeaturesCol("features")
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala
index d9375361b..a16d53c97 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala
@@ -28,9 +28,14 @@ object SparkTraining {
   def main(args: Array[String]): Unit = {
     if (args.length < 1) {
       // scalastyle:off
-      println("Usage: program input_path")
+      println("Usage: program input_path [cpu|gpu]")
       sys.exit(1)
     }
+
+    val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
+      ("gpu_hist", 1)
+    } else ("auto", 2)
+
     val spark = SparkSession.builder().getOrCreate()
     val inputPath = args(0)
     val schema = new StructType(Array(
@@ -68,7 +73,8 @@ object SparkTraining {
       "objective" -> "multi:softprob",
       "num_class" -> 3,
       "num_round" -> 100,
-      "num_workers" -> 2,
+      "num_workers" -> numWorkers,
+      "tree_method" -> treeMethod,
       "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
     val xgbClassifier = new XGBoostClassifier(xgbParam).
       setFeaturesCol("features").
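Taken together, the three profiles reconstructed above suggest the following invocations; a sketch based on the profile ids and the use.cuda property, not an official build matrix:

    # default profile: CPU build; suites tagged with GpuTestSuite are excluded
    mvn --no-transfer-progress package
    # 'gpu' profile (activated by use.cuda=ON): compiles the CUDA/NCCL code, skips tests
    mvn --no-transfer-progress package -Duse.cuda=ON
    # 'gpu-with-gpu-tests' profile: compiles the CUDA/NCCL code and runs the GPU-tagged suites
    mvn --no-transfer-progress package -Pgpu-with-gpu-tests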
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index c0354866e..f1a6f13ef 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -22,7 +22,6 @@ import java.nio.file.Files
 import scala.collection.{AbstractIterator, mutable}
 import scala.util.Random
 import scala.collection.JavaConverters._
-
 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
@@ -32,7 +31,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.FileSystem
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
 import org.apache.spark.sql.SparkSession
@@ -76,7 +74,9 @@ private[this] case class XGBoostExecutionParams(
     checkpointParam: Option[ExternalCheckpointParams],
     xgbInputParams: XGBoostExecutionInputParams,
     earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
-    cacheTrainingSet: Boolean) {
+    cacheTrainingSet: Boolean,
+    treeMethod: Option[String],
+    isLocal: Boolean) {
 
   private var rawParamMap: Map[String, Any] = _
 
@@ -93,6 +93,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
 
   private val logger = LogFactory.getLog("XGBoostSpark")
 
+  private val isLocal = sc.isLocal
+
   private val overridedParams = overrideParams(rawParams, sc)
 
   /**
@@ -168,11 +170,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       .getOrElse("allow_non_zero_for_missing", false)
       .asInstanceOf[Boolean]
     validateSparkSslConf
+    var treeMethod: Option[String] = None
     if (overridedParams.contains("tree_method")) {
       require(overridedParams("tree_method") == "hist" ||
         overridedParams("tree_method") == "approx" ||
-        overridedParams("tree_method") == "auto", "xgboost4j-spark only supports tree_method as" +
-        " 'hist', 'approx' and 'auto'")
+        overridedParams("tree_method") == "auto" ||
+        overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" +
+        " as 'hist', 'approx', 'gpu_hist', and 'auto'")
+      treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
     }
     if (overridedParams.contains("train_test_ratio")) {
       logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
@@ -221,7 +226,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       checkpointParam,
       inputParams,
       xgbExecEarlyStoppingParams,
-      cacheTrainingSet)
+      cacheTrainingSet,
+      treeMethod,
+      isLocal)
     xgbExecParam.setRawParamMap(overridedParams)
     xgbExecParam
   }
@@ -335,6 +342,26 @@ object XGBoost extends Serializable {
     }
   }
 
+  private def getGPUAddrFromResources: Int = {
+    val tc = TaskContext.get()
+    if (tc == null) {
+      throw new RuntimeException("TaskContext is null; GPU addresses can only be " +
+        "read from within a Spark task")
+    }
+    val resources = tc.resources()
+    if (resources.contains("gpu")) {
+      val addrs = resources("gpu").addresses
+      if (addrs.size > 1) {
+        // TODO: should we throw an exception instead?
+        logger.warn("XGBoost only supports one GPU per worker")
+      }
+      // take the first one
+      addrs.head.toInt
+    } else {
+      throw new RuntimeException("No GPU was allocated to this task by Spark; " +
+        "please check that GPU scheduling is enabled")
+    }
+  }
+
   private def buildDistributedBooster(
       watches: Watches,
       xgbExecutionParam: XGBoostExecutionParams,
@@ -362,13 +389,25 @@ object XGBoost extends Serializable {
     val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
     val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
     val externalCheckpointParams = xgbExecutionParam.checkpointParam
+
+    var params = xgbExecutionParam.toMap
+    if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) {
+      val gpuId = if (xgbExecutionParam.isLocal) {
+        // For local mode, force the gpu id to the primary device
+        0
+      } else {
+        getGPUAddrFromResources
+      }
+      logger.info("Leveraging gpu device " + gpuId + " to train")
+      params = params + ("gpu_id" -> gpuId)
+    }
     val booster = if (makeCheckpoint) {
       SXGBoost.trainAndSaveCheckpoint(
-        watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
+        watches.toMap("train"), params, numRounds,
         watches.toMap, metrics, obj, eval,
         earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
     } else {
-      SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
+      SXGBoost.train(watches.toMap("train"), params, numRounds,
        watches.toMap, metrics, obj, eval,
         earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
     }
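getGPUAddrFromResources above leans on Spark 3.0's GPU-aware scheduling: TaskContext.resources() only contains a "gpu" entry when the application requests GPU resources. A hedged sketch of the submit-side configuration — the master URL, jar name, and script path are placeholders; the conf keys are standard Spark 3.0 resource-scheduling options:

    # Request one GPU per executor and one per task, so each training task
    # sees exactly one address in TaskContext.resources()("gpu").
    spark-submit \
      --master spark://master:7077 \
      --conf spark.executor.resource.gpu.amount=1 \
      --conf spark.task.resource.gpu.amount=1 \
      --conf spark.executor.resource.gpu.discoveryScript=/opt/spark/getGpusResources.sh \
      --class ml.dmlc.xgboost4j.scala.example.spark.SparkTraining \
      xgboost4j-example.jar iris.csv gpu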
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
index 1a7dd2a73..692a83630 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
@@ -145,11 +145,12 @@ private[spark] trait BoosterParams extends Params {
   final def getAlpha: Double = $(alpha)
 
   /**
-   * The tree construction algorithm used in XGBoost. options: {'auto', 'exact', 'approx'}
-   * [default='auto']
+   * The tree construction algorithm used in XGBoost. options:
+   * {'auto', 'exact', 'approx', 'hist', 'gpu_hist'} [default='auto']
    */
   final val treeMethod = new Param[String](this, "treeMethod",
-    "The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}",
+    "The tree construction algorithm used in XGBoost, options: " +
+    "{'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
     (value: String) => BoosterParams.supportedTreeMethods.contains(value))
 
   final def getTreeMethod: String = $(treeMethod)
@@ -292,7 +293,7 @@ private[spark] object BoosterParams {
 
   val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
 
-  val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist")
+  val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")
 
   val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
index 5eee79ef2..dd9e32516 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
@@ -261,10 +261,10 @@ private[spark] trait ParamMapFuncs extends Params {
     for ((paramName, paramValue) <- xgboostParams) {
       if ((paramName == "booster" && paramValue != "gbtree") ||
         (paramName == "updater" && paramValue != "grow_histmaker,prune" &&
-          paramValue != "hist")) {
+          paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
         throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
-          s" XGBoost-Spark only supports gbtree as booster type" +
-          " and grow_histmaker,prune or hist as the updater type")
+          s" XGBoost-Spark only supports gbtree as the booster type, and" +
+          s" grow_histmaker,prune, grow_quantile_histmaker, or grow_gpu_hist as the updater type")
       }
       val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
       params.find(_.name == name).foreach {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index b1dda665b..4f7dfe6c9 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -16,82 +16,16 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
+import ml.dmlc.xgboost4j.java.GpuTestSuite
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import org.apache.spark.ml.linalg._
 import org.apache.spark.sql._
 import org.scalatest.FunSuite
 import org.apache.spark.Partitioner
 
-class XGBoostClassifierSuite extends FunSuite with PerTest {
+abstract class XGBoostClassifierSuiteBase extends FunSuite with PerTest {
 
-  test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
-    val trainingDM = new DMatrix(Classification.train.iterator)
-    val testDM = new DMatrix(Classification.test.iterator)
-    val trainingDF = buildDataFrame(Classification.train)
-    val testDF = buildDataFrame(Classification.test)
-    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
-  }
-
-  test("XGBoostClassifier should make correct predictions after upstream random sort") {
-    val trainingDM = new DMatrix(Classification.train.iterator)
-    val testDM = new DMatrix(Classification.test.iterator)
-    val trainingDF = buildDataFrameWithRandSort(Classification.train)
-    val testDF = buildDataFrameWithRandSort(Classification.test)
-    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
-  }
-
-  private def checkResultsWithXGBoost4j(
-      trainingDM: DMatrix,
-      testDM: DMatrix,
-      trainingDF: DataFrame,
-      testDF: DataFrame,
-      round: Int = 5): Unit = {
-    val paramMap = Map(
-      "eta" -> "1",
-      "max_depth" -> "6",
-      "silent" -> "1",
-      "objective" -> "binary:logistic")
-
-    val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
-    val prediction1 = model1.predict(testDM)
-
-    val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
-      "num_workers" -> numWorkers)).fit(trainingDF)
-
-    val prediction2 = model2.transform(testDF).
-      collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
-
-    assert(testDF.count() === prediction2.size)
-    // the vector length in probability column is 2 since we have to fit to the evaluator in Spark
-    for (i <- prediction1.indices) {
-      assert(prediction1(i).length === prediction2(i).values.length - 1)
-      for (j <- prediction1(i).indices) {
-        assert(prediction1(i)(j) === prediction2(i)(j + 1))
-      }
-    }
-
-    val prediction3 = model1.predict(testDM, outPutMargin = true)
-    val prediction4 = model2.transform(testDF).
-      collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
-
-    assert(testDF.count() === prediction4.size)
-    // the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
-    for (i <- prediction3.indices) {
-      assert(prediction3(i).length === prediction4(i).values.length - 1)
-      for (j <- prediction3(i).indices) {
-        assert(prediction3(i)(j) === prediction4(i)(j + 1))
-      }
-    }
-
-    // check the equality of single instance prediction
-    val firstOfDM = testDM.slice(Array(0))
-    val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
-      .head()
-      .getAs[Vector]("features")
-    val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
-    val prediction6 = model2.predict(firstOfDF)
-    assert(prediction5 === prediction6)
-  }
+  protected val treeMethod: String = "auto"
 
   test("Set params in XGBoost and MLlib way should produce same model") {
     val trainingDF = buildDataFrame(Classification.train)
@@ -104,6 +38,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
       "silent" -> "1",
       "objective" -> "binary:logistic",
       "num_round" -> round,
+      "tree_method" -> treeMethod,
       "num_workers" -> numWorkers)
 
     // Set params in XGBoost way
@@ -128,7 +63,8 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
 
   test("test schema of XGBoostClassificationModel") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
+      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
+      "tree_method" -> treeMethod)
 
     val trainingDF = buildDataFrame(Classification.train)
     val testDF = buildDataFrame(Classification.test)
@@ -160,7 +96,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
   test("multi class classification") {
     val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
-      "num_workers" -> numWorkers)
+      "num_workers" -> numWorkers, "tree_method" -> treeMethod)
     val trainingDF = buildDataFrame(MultiClassification.train)
     val xgb = new XGBoostClassifier(paramMap)
     val model = xgb.fit(trainingDF)
@@ -175,7 +111,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     val test = buildDataFrame(Classification.test)
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
-      "num_round" -> 5, "num_workers" -> numWorkers)
+      "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
 
     val xgb = new XGBoostClassifier(paramMap)
     val model1 = xgb.fit(training1)
@@ -194,7 +130,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
   test("test predictionLeaf") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
-      "num_round" -> 5, "num_workers" -> numWorkers)
+      "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
     val training = buildDataFrame(Classification.train)
     val test = buildDataFrame(Classification.test)
     val groundTruth = test.count()
@@ -209,7 +145,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
   test("test predictionLeaf with empty column name") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
-      "num_round" -> 5, "num_workers" -> numWorkers)
+      "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
     val training = buildDataFrame(Classification.train)
     val test = buildDataFrame(Classification.test)
     val xgb = new XGBoostClassifier(paramMap)
@@ -222,7 +158,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
   test("test predictionContrib") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
-      "num_round" -> 5, "num_workers" -> numWorkers)
+      "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
     val training = buildDataFrame(Classification.train)
     val test = buildDataFrame(Classification.test)
     val groundTruth = test.count()
@@ -237,7 +173,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
   test("test predictionContrib with empty column name") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
-      "num_round" -> 5, "num_workers" -> numWorkers)
+      "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
     val training = buildDataFrame(Classification.train)
     val test = buildDataFrame(Classification.test)
     val xgb = new XGBoostClassifier(paramMap)
@@ -250,7 +186,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
   test("test predictionLeaf and predictionContrib") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
-      "num_round" -> 5, "num_workers" -> numWorkers)
+      "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
     val training = buildDataFrame(Classification.train)
     val test = buildDataFrame(Classification.test)
     val groundTruth = test.count()
@@ -264,6 +200,80 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     assert(resultDF.columns.contains("predictContrib"))
   }
 
+}
+
+class XGBoostCpuClassifierSuite extends XGBoostClassifierSuiteBase {
+  test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
+    val trainingDM = new DMatrix(Classification.train.iterator)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val trainingDF = buildDataFrame(Classification.train)
+    val testDF = buildDataFrame(Classification.test)
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  test("XGBoostClassifier should make correct predictions after upstream random sort") {
+    val trainingDM = new DMatrix(Classification.train.iterator)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val trainingDF = buildDataFrameWithRandSort(Classification.train)
+    val testDF = buildDataFrameWithRandSort(Classification.test)
+    checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+  }
+
+  private def checkResultsWithXGBoost4j(
+      trainingDM: DMatrix,
+      testDM: DMatrix,
+      trainingDF: DataFrame,
+      testDF: DataFrame,
+      round: Int = 5): Unit = {
+    val paramMap = Map(
+      "eta" -> "1",
+      "max_depth" -> "6",
+      "silent" -> "1",
+      "objective" -> "binary:logistic",
+      "tree_method" -> treeMethod,
+      "max_bin" -> 16)
+
+    val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
+    val prediction1 = model1.predict(testDM)
+
+    val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
+      "num_workers" -> numWorkers)).fit(trainingDF)
+
+    val prediction2 = model2.transform(testDF).
+      collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
+
+    assert(testDF.count() === prediction2.size)
+    // the vector length in probability column is 2 since we have to fit to the evaluator in Spark
+    for (i <- prediction1.indices) {
+      assert(prediction1(i).length === prediction2(i).values.length - 1)
+      for (j <- prediction1(i).indices) {
+        assert(prediction1(i)(j) === prediction2(i)(j + 1))
+      }
+    }
+
+    val prediction3 = model1.predict(testDM, outPutMargin = true)
+    val prediction4 = model2.transform(testDF).
+      collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
+
+    assert(testDF.count() === prediction4.size)
+    // the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
+    for (i <- prediction3.indices) {
+      assert(prediction3(i).length === prediction4(i).values.length - 1)
+      for (j <- prediction3(i).indices) {
+        assert(prediction3(i)(j) === prediction4(i)(j + 1))
+      }
+    }
+
+    // check the equality of single instance prediction
+    val firstOfDM = testDM.slice(Array(0))
+    val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
+      .head()
+      .getAs[Vector]("features")
+    val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
+    val prediction6 = model2.predict(firstOfDF)
+    assert(prediction5 === prediction6)
+  }
+
   test("infrequent features") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic",
@@ -305,5 +315,10 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     val xgb = new XGBoostClassifier(paramMap)
     xgb.fit(repartitioned)
   }
-
+}
+
+@GpuTestSuite
+class XGBoostGpuClassifierSuite extends XGBoostClassifierSuiteBase {
+  override protected val treeMethod: String = "gpu_hist"
+  override protected val numWorkers: Int = 1
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
index fc1f8d906..7a8bf6fa4 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
@@ -16,6 +16,7 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
+import ml.dmlc.xgboost4j.java.GpuTestSuite
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.sql.functions._
@@ -23,7 +24,8 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types._
 import org.scalatest.FunSuite
 
-class XGBoostRegressorSuite extends FunSuite with PerTest {
+abstract class XGBoostRegressorSuiteBase extends FunSuite with PerTest {
+  protected val treeMethod: String = "auto"
 
   test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
     val trainingDM = new DMatrix(Regression.train.iterator)
@@ -51,7 +53,9 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
       "eta" -> "1",
       "max_depth" -> "6",
       "silent" -> "1",
-      "objective" -> "reg:squarederror")
+      "objective" -> "reg:squarederror",
+      "max_bin" -> 16,
+      "tree_method" -> treeMethod)
 
     val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
     val prediction1 = model1.predict(testDM)
@@ -88,6 +92,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
       "silent" -> "1",
       "objective" -> "reg:squarederror",
       "num_round" -> round,
+      "tree_method" -> treeMethod,
       "num_workers" -> numWorkers)
 
     // Set params in XGBoost way
@@ -99,6 +104,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
       .setSilent(1)
       .setObjective("reg:squarederror")
      .setNumRound(round)
+      .setTreeMethod(treeMethod)
       .setNumWorkers(numWorkers)
       .fit(trainingDF)
 
@@ -113,7 +119,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
   test("ranking: use group data") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5,
-      "group_col" -> "group")
+      "group_col" -> "group", "tree_method" -> treeMethod)
 
     val trainingDF = buildDataFrameWithGroup(Ranking.train)
     val testDF = buildDataFrame(Ranking.test)
@@ -125,7 +131,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
   test("use weight") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
+      "tree_method" -> treeMethod)
     val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f})
     val trainingDF = buildDataFrame(Regression.train)
@@ -140,7 +147,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
   test("test predictionLeaf") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
+      "tree_method" -> treeMethod)
     val training = buildDataFrame(Regression.train)
     val testDF = buildDataFrame(Regression.test)
     val groundTruth = testDF.count()
@@ -154,7 +162,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
   test("test predictionLeaf with empty column name") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
+      "tree_method" -> treeMethod)
     val training = buildDataFrame(Regression.train)
     val testDF = buildDataFrame(Regression.test)
     val xgb = new XGBoostRegressor(paramMap)
@@ -166,7 +175,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
 
   test("test predictionContrib") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers) + "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers, + "tree_method" -> treeMethod) val training = buildDataFrame(Regression.train) val testDF = buildDataFrame(Regression.test) val groundTruth = testDF.count() @@ -180,7 +190,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest { test("test predictionContrib with empty column name") { val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers) + "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers, + "tree_method" -> treeMethod) val training = buildDataFrame(Regression.train) val testDF = buildDataFrame(Regression.test) val xgb = new XGBoostRegressor(paramMap) @@ -192,7 +203,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest { test("test predictionLeaf and predictionContrib") { val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers) + "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers, + "tree_method" -> treeMethod) val training = buildDataFrame(Regression.train) val testDF = buildDataFrame(Regression.test) val groundTruth = testDF.count() @@ -206,3 +218,13 @@ class XGBoostRegressorSuite extends FunSuite with PerTest { assert(resultDF.columns.contains("predictContrib")) } } + +class XGBoostCpuRegressorSuite extends XGBoostRegressorSuiteBase { + +} + +@GpuTestSuite +class XGBoostGpuRegressorSuite extends XGBoostRegressorSuiteBase { + override protected val treeMethod: String = "gpu_hist" + override protected val numWorkers: Int = 1 +} diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml index cf5558058..62a96c41e 100644 --- a/jvm-packages/xgboost4j/pom.xml +++ b/jvm-packages/xgboost4j/pom.xml @@ -43,6 +43,12 @@ 2.5.23 test + + org.scalatest + scalatest_${scala.binary.version} + 3.0.5 + compile + @@ -78,6 +84,8 @@ python create_jni.py + --use-cuda + ${use.cuda} ${user.dir} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/GpuTestSuite.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/GpuTestSuite.java new file mode 100644 index 000000000..60d25dab8 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/GpuTestSuite.java @@ -0,0 +1,28 @@ +/* + Copyright (c) 2020 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j.java; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface GpuTestSuite {} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index 791f06b71..90d06c343 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -46,6 +46,7 @@ object XGBoost { } else { prevBooster.booster } + val xgboostInJava = checkpointParams. map(cp => { JXGBoost.trainAndSaveCheckpoint( diff --git a/tests/ci_build/Dockerfile.gpu_jvm b/tests/ci_build/Dockerfile.gpu_jvm new file mode 100644 index 000000000..acd7b9b86 --- /dev/null +++ b/tests/ci_build/Dockerfile.gpu_jvm @@ -0,0 +1,51 @@ +ARG CUDA_VERSION +FROM nvidia/cuda:$CUDA_VERSION-runtime-ubuntu16.04 +ARG JDK_VERSION=8 +ARG SPARK_VERSION=3.0.0 + +# Environment +ENV DEBIAN_FRONTEND noninteractive + +# Install all basic requirements +RUN \ + apt-get update && \ + apt-get install -y software-properties-common && \ + add-apt-repository ppa:openjdk-r/ppa && \ + apt-get update && \ + apt-get install -y tar unzip wget openjdk-$JDK_VERSION-jdk libgomp1 && \ + # Python + wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + bash Miniconda3.sh -b -p /opt/python && \ + /opt/python/bin/pip install awscli && \ + # Maven + wget https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \ + tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \ + ln -s /opt/apache-maven-3.6.1/ /opt/maven && \ + # Spark + wget https://archive.apache.org/dist/spark/spark-$SPARK_VERSION/spark-$SPARK_VERSION-bin-hadoop2.7.tgz && \ + tar xvf spark-$SPARK_VERSION-bin-hadoop2.7.tgz -C /opt && \ + ln -s /opt/spark-$SPARK_VERSION-bin-hadoop2.7 /opt/spark + +ENV PATH=/opt/python/bin:/opt/spark/bin:/opt/maven/bin:$PATH + +# Install Python packages +RUN \ + pip install numpy scipy pandas scikit-learn + +ENV GOSU_VERSION 1.10 + +# Install lightweight sudo (not bound to TTY) +RUN set -ex; \ + wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \ + chmod +x /usr/local/bin/gosu && \ + gosu nobody true + +# Set default JDK version +RUN update-java-alternatives -v -s java-1.$JDK_VERSION.0-openjdk-amd64 + +# Default entry-point to use if running locally +# It will preserve attributes of created files +COPY entrypoint.sh /scripts/ + +WORKDIR /workspace +ENTRYPOINT ["/scripts/entrypoint.sh"] diff --git a/tests/ci_build/Dockerfile.jvm_gpu_build b/tests/ci_build/Dockerfile.jvm_gpu_build new file mode 100644 index 000000000..ed6c3d689 --- /dev/null +++ b/tests/ci_build/Dockerfile.jvm_gpu_build @@ -0,0 +1,63 @@ +ARG CUDA_VERSION +FROM nvidia/cuda:$CUDA_VERSION-devel-centos6 +ARG CUDA_VERSION + +# Environment +ENV DEBIAN_FRONTEND noninteractive +ENV DEVTOOLSET_URL_ROOT http://vault.centos.org/6.9/sclo/x86_64/rh/devtoolset-4/ + +# Install all basic requirements +RUN \ + yum -y update && \ + yum install -y tar unzip wget xz git centos-release-scl yum-utils java-1.8.0-openjdk-devel && \ + yum-config-manager --enable centos-sclo-rh-testing && \ + yum -y update && \ + 
+    yum install -y $DEVTOOLSET_URL_ROOT/devtoolset-4-gcc-5.3.1-6.1.el6.x86_64.rpm \
+                   $DEVTOOLSET_URL_ROOT/devtoolset-4-gcc-c++-5.3.1-6.1.el6.x86_64.rpm \
+                   $DEVTOOLSET_URL_ROOT/devtoolset-4-binutils-2.25.1-8.el6.x86_64.rpm \
+                   $DEVTOOLSET_URL_ROOT/devtoolset-4-runtime-4.1-3.sc1.el6.x86_64.rpm \
+                   $DEVTOOLSET_URL_ROOT/devtoolset-4-libstdc++-devel-5.3.1-6.1.el6.x86_64.rpm && \
+    # Python
+    wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+    bash Miniconda3.sh -b -p /opt/python && \
+    # CMake
+    wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
+    bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr && \
+    # Maven
+    wget https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \
+    tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \
+    ln -s /opt/apache-maven-3.6.1/ /opt/maven
+
+# NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html)
+RUN \
+    export CUDA_SHORT=`echo $CUDA_VERSION | egrep -o '[0-9]+\.[0-9]'` && \
+    export NCCL_VERSION=2.4.8-1 && \
+    wget https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
+    rpm -i nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
+    yum -y update && \
+    yum install -y libnccl-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-devel-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-static-${NCCL_VERSION}+cuda${CUDA_SHORT} && \
+    rm -f nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm;
+
+ENV PATH=/opt/python/bin:/opt/maven/bin:$PATH
+ENV CC=/opt/rh/devtoolset-4/root/usr/bin/gcc
+ENV CXX=/opt/rh/devtoolset-4/root/usr/bin/c++
+ENV CPP=/opt/rh/devtoolset-4/root/usr/bin/cpp
+
+# Install Python packages
+RUN \
+    pip install numpy pytest scipy scikit-learn wheel kubernetes urllib3==1.22 awscli
+
+ENV GOSU_VERSION 1.10
+
+# Install lightweight sudo (not bound to TTY)
+RUN set -ex; \
+    wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
+    chmod +x /usr/local/bin/gosu && \
+    gosu nobody true
+
+# Default entry-point to use if running locally
+# It will preserve attributes of created files
+COPY entrypoint.sh /scripts/
+
+WORKDIR /workspace
+ENTRYPOINT ["/scripts/entrypoint.sh"]
diff --git a/tests/ci_build/build_jvm_packages.sh b/tests/ci_build/build_jvm_packages.sh
index 8190aa1e1..dcb80a162 100755
--- a/tests/ci_build/build_jvm_packages.sh
+++ b/tests/ci_build/build_jvm_packages.sh
@@ -3,12 +3,15 @@
 set -e
 set -x
 
-if [ $# -ne 1 ]; then
-  echo "Usage: $0 [spark version]"
-  exit 1
-fi
-
 spark_version=$1
+use_cuda=$2
+gpu_arch=$3
+
+gpu_options=""
+if [ "x$use_cuda" == "x-Duse.cuda=ON" ]; then
+  # The plain CPU build runs the unit tests as part of `mvn package`; for CUDA
+  # builds, select the gpu-with-gpu-tests profile so the GPU-tagged suites run.
+  gpu_options=" -Pgpu-with-gpu-tests "
+fi
 
 # Initialize local Maven repository
 ./tests/ci_build/initialize_maven.sh
@@ -16,7 +19,11 @@
 rm -rf build/
 cd jvm-packages
 export RABIT_MOCK=ON
-mvn --no-transfer-progress package -Dspark.version=${spark_version}
+
+if [ "x$gpu_arch" != "x" ]; then
+  export GPU_ARCH_FLAG=$gpu_arch
+fi
+mvn --no-transfer-progress package -Dspark.version=${spark_version} $gpu_options
 
 set +x
 set +e
diff --git a/tests/ci_build/test_jvm_gpu_cross.sh b/tests/ci_build/test_jvm_gpu_cross.sh
new file mode 100755
index 000000000..51ccfa32b
--- /dev/null
+++ b/tests/ci_build/test_jvm_gpu_cross.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+set -e
+set -x
+
+nvidia-smi
+
+ls /usr/local/
+
+# Initialize local Maven repository
+./tests/ci_build/initialize_maven.sh
+
+# Get version number of XGBoost4J and other auxiliary information
+cd jvm-packages
+xgboost4j_version=$(mvn help:evaluate -Dexpression=project.version -q -DforceStdout)
+scala_binary_version=$(mvn help:evaluate -Dexpression=scala.binary.version -q -DforceStdout)
+
+python3 xgboost4j-tester/get_iris.py
+xgb_jars="./xgboost4j/target/xgboost4j_${scala_binary_version}-${xgboost4j_version}.jar,./xgboost4j-spark/target/xgboost4j-spark_${scala_binary_version}-${xgboost4j_version}.jar"
+example_jar="./xgboost4j-example/target/xgboost4j-example_${scala_binary_version}-${xgboost4j_version}.jar"
+
+echo "Run SparkTraining locally ... "
+spark-submit \
+  --master 'local[1]' \
+  --class ml.dmlc.xgboost4j.scala.example.spark.SparkTraining \
+  --jars $xgb_jars \
+  $example_jar \
+  ${PWD}/iris.csv gpu
+
+echo "Run SparkMLlibPipeline locally ... "
+spark-submit \
+  --master 'local[1]' \
+  --class ml.dmlc.xgboost4j.scala.example.spark.SparkMLlibPipeline \
+  --jars $xgb_jars \
+  $example_jar \
+  ${PWD}/iris.csv ${PWD}/native_model ${PWD}/pipeline_model gpu
+
+set +x
+set +e
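The script above runs in local mode, where the patch pins gpu_id to device 0 and no resource discovery is needed. On a real cluster, executors need a GPU discovery script for the scheduling configs shown earlier; a minimal sketch modeled on the getGpusResources.sh example bundled with Spark — the path and exact output contract should be checked against your Spark version:

    #!/usr/bin/env bash
    # Print the JSON a Spark resource discovery script must emit,
    # e.g. {"name": "gpu", "addresses": ["0", "1"]}
    ADDRS=$(nvidia-smi --query-gpu=index --format=csv,noheader \
            | sed -e 's/^/"/' -e 's/$/"/' | paste -sd, -)
    echo "{\"name\": \"gpu\", \"addresses\": [${ADDRS}]}"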