[BLOCKING] [jvm-packages] add gpu_hist and enable gpu scheduling (#5171)

* [jvm-packages] add gpu_hist tree method

* change updater hist to grow_quantile_histmaker

* add gpu scheduling

* pass correct parameters to xgboost library

* remove debug info

* add use.cuda for pom

* add CI for gpu_hist for jvm

* add gpu unit tests

* use gpu node to build jvm

* use nvidia-docker

* Add CLI interface to create_jni.py using argparse

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
Bobby Wang 2020-07-27 12:53:24 +08:00 committed by GitHub
parent 6347fa1c2e
commit 8943eb4314
18 changed files with 543 additions and 123 deletions

Jenkinsfile

@ -75,6 +75,7 @@ pipeline {
'build-gpu-cuda10.1': { BuildCUDA(cuda_version: '10.1') },
'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2') },
'build-gpu-cuda11.0': { BuildCUDA(cuda_version: '11.0') },
'build-jvm-packages-gpu-cuda10.0': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '10.0') },
'build-jvm-packages': { BuildJVMPackages(spark_version: '3.0.0') },
'build-jvm-doc': { BuildJVMDoc() }
])
@ -94,6 +95,7 @@ pipeline {
'test-python-mgpu-cuda10.2': { TestPythonGPU(host_cuda_version: '10.2', multi_gpu: true) },
'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2') },
'test-cpp-gpu-cuda11.0': { TestCppGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') },
'test-jvm-jdk8-cuda10.0': { CrossTestJVMwithJDKGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.0') },
'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '3.0.0') },
'test-jvm-jdk11': { CrossTestJVMwithJDK(jdk_version: '11') },
'test-jvm-jdk12': { CrossTestJVMwithJDK(jdk_version: '12') },
@ -282,6 +284,28 @@ def BuildCUDA(args) {
}
}
def BuildJVMPackagesWithCUDA(args) {
node('linux && gpu') {
unstash name: 'srcs'
echo "Build XGBoost4J-Spark with Spark ${args.spark_version}, CUDA ${args.cuda_version}"
def container_type = "jvm_gpu_build"
def docker_binary = "nvidia-docker"
def docker_args = "--build-arg CUDA_VERSION=${args.cuda_version}"
def arch_flag = ""
if (env.BRANCH_NAME != 'master' && !(env.BRANCH_NAME.startsWith('release'))) {
arch_flag = "-DGPU_COMPUTE_VER=75"
}
// Use only 4 CPU cores
def docker_extra_params = "CI_DOCKER_EXTRA_PARAMS_INIT='--cpuset-cpus 0-3'"
sh """
${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/build_jvm_packages.sh ${args.spark_version} -Duse.cuda=ON $arch_flag
"""
echo "Stashing XGBoost4J JAR with CUDA ${args.cuda_version} ..."
stash name: 'xgboost4j_jar_gpu', includes: "jvm-packages/xgboost4j/target/*.jar,jvm-packages/xgboost4j-spark/target/*.jar,jvm-packages/xgboost4j-example/target/*.jar"
deleteDir()
}
}
def BuildJVMPackages(args) {
node('linux && cpu') {
unstash name: 'srcs'
@ -386,6 +410,24 @@ def TestCppGPU(args) {
}
}
def CrossTestJVMwithJDKGPU(args) {
def nodeReq = 'linux && mgpu'
node(nodeReq) {
unstash name: "xgboost4j_jar_gpu"
unstash name: 'srcs'
if (args.spark_version != null) {
echo "Test XGBoost4J on a machine with JDK ${args.jdk_version}, Spark ${args.spark_version}, CUDA ${args.host_cuda_version}"
} else {
echo "Test XGBoost4J on a machine with JDK ${args.jdk_version}, CUDA ${args.host_cuda_version}"
}
def container_type = "gpu_jvm"
def docker_binary = "nvidia-docker"
def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
sh "${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_jvm_gpu_cross.sh"
deleteDir()
}
}
def CrossTestJVMwithJDK(args) {
node('linux && cpu') {
unstash name: 'xgboost4j_jar'

doc/jvm/index.rst

@ -202,6 +202,14 @@ If you are on Mac OS and using a compiler that supports OpenMP, you need to go to the file ``create_jni.py`` and comment out the line ``CONFIG["USE_OPENMP"] = "OFF"``
in order to get the benefit of multi-threading.
Building with GPU support
-------------------------
If you want to build XGBoost4J that supports distributed GPU training, run
.. code-block:: bash
mvn -Duse.cuda=ON install
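
The ``use.cuda`` property also activates the new ``gpu`` Maven profile, so the GPU-tagged test suites run in addition to the CPU ones. As a hedged sketch (profile names as defined in ``jvm-packages/pom.xml`` in this commit), you can instead build with CUDA and run only the GPU suites:

.. code-block:: bash

  mvn -Pgpu-with-gpu-tests install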
********
Contents
********

jvm-packages/create_jni.py

@ -1,5 +1,6 @@
#!/usr/bin/env python
import errno
import argparse
import glob
import os
import shutil
@ -7,7 +8,6 @@ import subprocess
import sys
from contextlib import contextmanager
# Monkey-patch the API inconsistency between Python2.X and 3.X.
if sys.platform.startswith("linux"):
sys.platform = "linux"
@ -20,6 +20,7 @@ CONFIG = {
"USE_S3": "OFF",
"USE_CUDA": "OFF",
"USE_NCCL": "OFF",
"JVM_BINDINGS": "ON"
}
@ -68,6 +69,10 @@ def normpath(path):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--use-cuda', type=str, choices=['ON', 'OFF'], default='OFF')
cli_args = parser.parse_args()
if sys.platform == "darwin":
# Enable if your compiler supports OpenMP.
CONFIG["USE_OPENMP"] = "OFF"
@ -88,12 +93,21 @@ if __name__ == "__main__":
else:
maybe_parallel_build = ""
if cli_args.use_cuda == 'ON':
CONFIG['USE_CUDA'] = 'ON'
CONFIG['USE_NCCL'] = 'ON'
args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
# If the environment variable RABIT_MOCK is set, build with the Rabit mock enabled.
if os.getenv("RABIT_MOCK", None) is not None:
args.append("-DRABIT_MOCK:BOOL=ON")
# If the environment variable GPU_ARCH_FLAG is set, pass it through to CMake.
gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None)
if gpu_arch_flag is not None:
args.append("%s" % gpu_arch_flag)
run("cmake .. " + " ".join(args) + maybe_generator)
run("cmake --build . --config Release" + maybe_parallel_build)

jvm-packages/pom.xml

@ -38,6 +38,7 @@
<scala.version>2.12.8</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>2.7.3</hadoop.version>
<use.cuda>OFF</use.cuda>
</properties>
<repositories>
<repository>
@ -52,7 +53,65 @@
<module>xgboost4j-spark</module>
<module>xgboost4j-flink</module>
</modules>
<profiles>
<profile>
<!-- default active profile excluding gpu related test suites -->
<id>default</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
<configuration>
<tagsToExclude>ml.dmlc.xgboost4j.java.GpuTestSuite</tagsToExclude>
</configuration>
</plugin>
</plugins>
</build>
</profile>
<!-- gpu profile with both cpu and gpu test suites -->
<profile>
<id>gpu</id>
<activation>
<property>
<name>use.cuda</name>
<value>ON</value>
</property>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</profile>
<!-- gpu-with-gpu-tests profile with only gpu test suites -->
<profile>
<id>gpu-with-gpu-tests</id>
<properties>
<use.cuda>ON</use.cuda>
</properties>
<build>
<plugins>
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
<configuration>
<tagsToInclude>ml.dmlc.xgboost4j.java.GpuTestSuite</tagsToInclude>
</configuration>
</plugin>
</plugins>
</build>
</profile>
<profile>
<id>release</id>
<build>
@ -242,6 +301,25 @@
<filtering>true</filtering>
</resource>
</resources>
<pluginManagement>
<plugins>
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
<version>1.0</version>
<executions>
<execution>
<id>test</id>
<goals>
<goal>test</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</pluginManagement>
<plugins>
<plugin>
<groupId>org.scalastyle</groupId>
@ -336,15 +414,6 @@
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
<version>1.0</version>
<executions>
<execution>
<id>test</id>
<goals>
<goal>test</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
<extensions>
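
To summarize how these profiles are selected, a hedged sketch run from jvm-packages/ (the gpu profile is activated by the same use.cuda property that the xgboost4j build passes down to create_jni.py):

    # default profile: suites tagged ml.dmlc.xgboost4j.java.GpuTestSuite are excluded
    mvn test
    # 'gpu' profile: activated by use.cuda=ON, runs both CPU and GPU suites
    mvn test -Duse.cuda=ON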

jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala

@ -31,8 +31,9 @@ object SparkMLlibPipeline {
def main(args: Array[String]): Unit = {
if (args.length != 3) {
println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path")
if (args.length != 3 && args.length != 4) {
println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path " +
"[cpu|gpu]")
sys.exit(1)
}
@ -40,6 +41,10 @@ object SparkMLlibPipeline {
val nativeModelPath = args(1)
val pipelineModelPath = args(2)
val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
("gpu_hist", 1)
} else ("auto", 2)
val spark = SparkSession
.builder()
.appName("XGBoost4J-Spark Pipeline Example")
@ -76,7 +81,8 @@ object SparkMLlibPipeline {
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> 2
"num_workers" -> numWorkers,
"tree_method" -> treeMethod
)
)
booster.setFeaturesCol("features")

jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala

@ -28,9 +28,14 @@ object SparkTraining {
def main(args: Array[String]): Unit = {
if (args.length < 1) {
// scalastyle:off
println("Usage: program input_path")
println("Usage: program input_path [cpu|gpu]")
sys.exit(1)
}
val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
("gpu_hist", 1)
} else ("auto", 2)
val spark = SparkSession.builder().getOrCreate()
val inputPath = args(0)
val schema = new StructType(Array(
@ -68,7 +73,8 @@ object SparkTraining {
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> 2,
"num_workers" -> numWorkers,
"tree_method" -> treeMethod,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

@ -22,7 +22,6 @@ import java.nio.file.Files
import scala.collection.{AbstractIterator, mutable}
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
@ -32,7 +31,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
import org.apache.spark.sql.SparkSession
@ -76,7 +74,9 @@ private[this] case class XGBoostExecutionParams(
checkpointParam: Option[ExternalCheckpointParams],
xgbInputParams: XGBoostExecutionInputParams,
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean) {
cacheTrainingSet: Boolean,
treeMethod: Option[String],
isLocal: Boolean) {
private var rawParamMap: Map[String, Any] = _
@ -93,6 +93,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
private val logger = LogFactory.getLog("XGBoostSpark")
private val isLocal = sc.isLocal
private val overridedParams = overrideParams(rawParams, sc)
/**
@ -168,11 +170,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
.getOrElse("allow_non_zero_for_missing", false)
.asInstanceOf[Boolean]
validateSparkSslConf
var treeMethod: Option[String] = None
if (overridedParams.contains("tree_method")) {
require(overridedParams("tree_method") == "hist" ||
overridedParams("tree_method") == "approx" ||
overridedParams("tree_method") == "auto", "xgboost4j-spark only supports tree_method as" +
" 'hist', 'approx' and 'auto'")
overridedParams("tree_method") == "auto" ||
overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" +
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
}
if (overridedParams.contains("train_test_ratio")) {
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
@ -221,7 +226,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
checkpointParam,
inputParams,
xgbExecEarlyStoppingParams,
cacheTrainingSet)
cacheTrainingSet,
treeMethod,
isLocal)
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
@ -335,6 +342,26 @@ object XGBoost extends Serializable {
}
}
private def getGPUAddrFromResources: Int = {
val tc = TaskContext.get()
if (tc == null) {
throw new RuntimeException("Something wrong for task context")
}
val resources = tc.resources()
if (resources.contains("gpu")) {
val addrs = resources("gpu").addresses
if (addrs.size > 1) {
// TODO: should we throw an exception instead?
logger.warn("XGBoost only supports one GPU per worker")
}
// take the first one
addrs.head.toInt
} else {
throw new RuntimeException("gpu is not allocated by spark, " +
"please check if gpu scheduling is enabled")
}
}
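
TaskContext.resources() only contains a "gpu" entry when Spark 3.0 resource scheduling is enabled, which is what "enable gpu scheduling" in the commit title refers to. A hedged sketch of the spark-submit configuration this code path relies on (the discovery script location is an assumption; Spark 3.0 ships an example getGpusResources.sh):

    spark-submit \
      --conf spark.executor.resource.gpu.amount=1 \
      --conf spark.task.resource.gpu.amount=1 \
      --conf spark.executor.resource.gpu.discoveryScript=/path/to/getGpusResources.sh \
      ...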
private def buildDistributedBooster(
watches: Watches,
xgbExecutionParam: XGBoostExecutionParams,
@ -362,13 +389,25 @@ object XGBoost extends Serializable {
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
val externalCheckpointParams = xgbExecutionParam.checkpointParam
var params = xgbExecutionParam.toMap
if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) {
val gpuId = if (xgbExecutionParam.isLocal) {
// For local mode, force gpu id to primary device
0
} else {
getGPUAddrFromResources
}
logger.info("Leveraging gpu device " + gpuId + " to train")
params = params + ("gpu_id" -> gpuId)
}
val booster = if (makeCheckpoint) {
SXGBoost.trainAndSaveCheckpoint(
watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
watches.toMap("train"), params, numRounds,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
} else {
SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
SXGBoost.train(watches.toMap("train"), params, numRounds,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
}

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala

@ -145,11 +145,12 @@ private[spark] trait BoosterParams extends Params {
final def getAlpha: Double = $(alpha)
/**
* The tree construction algorithm used in XGBoost. options: {'auto', 'exact', 'approx'}
* [default='auto']
* The tree construction algorithm used in XGBoost. options:
* {'auto', 'exact', 'approx', 'hist', 'gpu_hist'} [default='auto']
*/
final val treeMethod = new Param[String](this, "treeMethod",
"The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}",
"The tree construction algorithm used in XGBoost, options: " +
"{'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
final def getTreeMethod: String = $(treeMethod)
@ -292,7 +293,7 @@ private[spark] object BoosterParams {
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist")
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")
val supportedGrowthPolicies = HashSet("depthwise", "lossguide")

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapFuncs.scala

@ -261,10 +261,10 @@ private[spark] trait ParamMapFuncs extends Params {
for ((paramName, paramValue) <- xgboostParams) {
if ((paramName == "booster" && paramValue != "gbtree") ||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
paramValue != "hist")) {
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
s" XGBoost-Spark only supports gbtree as booster type" +
" and grow_histmaker,prune or hist as the updater type")
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker,prune or" +
s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
}
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
params.find(_.name == name).foreach {

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala

@ -16,82 +16,16 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.GpuTestSuite
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.apache.spark.Partitioner
class XGBoostClassifierSuite extends FunSuite with PerTest {
abstract class XGBoostClassifierSuiteBase extends FunSuite with PerTest {
test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
}
test("XGBoostClassifier should make correct predictions after upstream random sort") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrameWithRandSort(Classification.train)
val testDF = buildDataFrameWithRandSort(Classification.test)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
}
private def checkResultsWithXGBoost4j(
trainingDM: DMatrix,
testDM: DMatrix,
trainingDF: DataFrame,
testDF: DataFrame,
round: Int = 5): Unit = {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic")
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
val prediction1 = model1.predict(testDM)
val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
"num_workers" -> numWorkers)).fit(trainingDF)
val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
assert(testDF.count() === prediction2.size)
// the vector length in probability column is 2 since we have to fit to the evaluator in Spark
for (i <- prediction1.indices) {
assert(prediction1(i).length === prediction2(i).values.length - 1)
for (j <- prediction1(i).indices) {
assert(prediction1(i)(j) === prediction2(i)(j + 1))
}
}
val prediction3 = model1.predict(testDM, outPutMargin = true)
val prediction4 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
assert(testDF.count() === prediction4.size)
// the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
for (i <- prediction3.indices) {
assert(prediction3(i).length === prediction4(i).values.length - 1)
for (j <- prediction3(i).indices) {
assert(prediction3(i)(j) === prediction4(i)(j + 1))
}
}
// check the equality of single instance prediction
val firstOfDM = testDM.slice(Array(0))
val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
.head()
.getAs[Vector]("features")
val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
val prediction6 = model2.predict(firstOfDF)
assert(prediction5 === prediction6)
}
protected val treeMethod: String = "auto"
test("Set params in XGBoost and MLlib way should produce same model") {
val trainingDF = buildDataFrame(Classification.train)
@ -104,6 +38,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
"silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> round,
"tree_method" -> treeMethod,
"num_workers" -> numWorkers)
// Set params in XGBoost way
@ -128,7 +63,8 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
test("test schema of XGBoostClassificationModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
@ -160,7 +96,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
test("multi class classification") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers)
"num_workers" -> numWorkers, "tree_method" -> treeMethod)
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
@ -175,7 +111,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
val test = buildDataFrame(Classification.test)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
"num_round" -> 5, "num_workers" -> numWorkers)
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val xgb = new XGBoostClassifier(paramMap)
val model1 = xgb.fit(training1)
@ -194,7 +130,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
test("test predictionLeaf") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Classification.train)
val test = buildDataFrame(Classification.test)
val groundTruth = test.count()
@ -209,7 +145,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
test("test predictionLeaf with empty column name") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Classification.train)
val test = buildDataFrame(Classification.test)
val xgb = new XGBoostClassifier(paramMap)
@ -222,7 +158,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
test("test predictionContrib") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Classification.train)
val test = buildDataFrame(Classification.test)
val groundTruth = test.count()
@ -237,7 +173,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
test("test predictionContrib with empty column name") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Classification.train)
val test = buildDataFrame(Classification.test)
val xgb = new XGBoostClassifier(paramMap)
@ -250,7 +186,7 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
test("test predictionLeaf and predictionContrib") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
"num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
val training = buildDataFrame(Classification.train)
val test = buildDataFrame(Classification.test)
val groundTruth = test.count()
@ -264,6 +200,80 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
assert(resultDF.columns.contains("predictContrib"))
}
}
class XGBoostCpuClassifierSuite extends XGBoostClassifierSuiteBase {
test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
}
test("XGBoostClassifier should make correct predictions after upstream random sort") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrameWithRandSort(Classification.train)
val testDF = buildDataFrameWithRandSort(Classification.test)
checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
}
private def checkResultsWithXGBoost4j(
trainingDM: DMatrix,
testDM: DMatrix,
trainingDF: DataFrame,
testDF: DataFrame,
round: Int = 5): Unit = {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic",
"tree_method" -> treeMethod,
"max_bin" -> 16)
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
val prediction1 = model1.predict(testDM)
val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
"num_workers" -> numWorkers)).fit(trainingDF)
val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
assert(testDF.count() === prediction2.size)
// the vector length in probability column is 2 since we have to fit to the evaluator in Spark
for (i <- prediction1.indices) {
assert(prediction1(i).length === prediction2(i).values.length - 1)
for (j <- prediction1(i).indices) {
assert(prediction1(i)(j) === prediction2(i)(j + 1))
}
}
val prediction3 = model1.predict(testDM, outPutMargin = true)
val prediction4 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
assert(testDF.count() === prediction4.size)
// the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
for (i <- prediction3.indices) {
assert(prediction3(i).length === prediction4(i).values.length - 1)
for (j <- prediction3(i).indices) {
assert(prediction3(i)(j) === prediction4(i)(j + 1))
}
}
// check the equality of single instance prediction
val firstOfDM = testDM.slice(Array(0))
val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
.head()
.getAs[Vector]("features")
val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
val prediction6 = model2.predict(firstOfDF)
assert(prediction5 === prediction6)
}
test("infrequent features") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
@ -305,5 +315,10 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
val xgb = new XGBoostClassifier(paramMap)
xgb.fit(repartitioned)
}
}
@GpuTestSuite
class XGBoostGpuClassifierSuite extends XGBoostClassifierSuiteBase {
override protected val treeMethod: String = "gpu_hist"
override protected val numWorkers: Int = 1
}

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala

@ -16,6 +16,7 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.GpuTestSuite
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions._
@ -23,7 +24,8 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
class XGBoostRegressorSuite extends FunSuite with PerTest {
abstract class XGBoostRegressorSuiteBase extends FunSuite with PerTest {
protected val treeMethod: String = "auto"
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
val trainingDM = new DMatrix(Regression.train.iterator)
@ -51,7 +53,9 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:squarederror")
"objective" -> "reg:squarederror",
"max_bin" -> 16,
"tree_method" -> treeMethod)
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
val prediction1 = model1.predict(testDM)
@ -88,6 +92,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
"silent" -> "1",
"objective" -> "reg:squarederror",
"num_round" -> round,
"tree_method" -> treeMethod,
"num_workers" -> numWorkers)
// Set params in XGBoost way
@ -99,6 +104,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
.setSilent(1)
.setObjective("reg:squarederror")
.setNumRound(round)
.setTreeMethod(treeMethod)
.setNumWorkers(numWorkers)
.fit(trainingDF)
@ -113,7 +119,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
test("ranking: use group data") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5,
"group_col" -> "group")
"group_col" -> "group", "tree_method" -> treeMethod)
val trainingDF = buildDataFrameWithGroup(Ranking.train)
val testDF = buildDataFrame(Ranking.test)
@ -125,7 +131,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
test("use weight") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f})
val trainingDF = buildDataFrame(Regression.train)
@ -140,7 +147,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
test("test predictionLeaf") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val groundTruth = testDF.count()
@ -154,7 +162,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
test("test predictionLeaf with empty column name") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val xgb = new XGBoostRegressor(paramMap)
@ -166,7 +175,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
test("test predictionContrib") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val groundTruth = testDF.count()
@ -180,7 +190,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
test("test predictionContrib with empty column name") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val xgb = new XGBoostRegressor(paramMap)
@ -192,7 +203,8 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
test("test predictionLeaf and predictionContrib") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
"tree_method" -> treeMethod)
val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val groundTruth = testDF.count()
@ -206,3 +218,13 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
assert(resultDF.columns.contains("predictContrib"))
}
}
class XGBoostCpuRegressorSuite extends XGBoostRegressorSuiteBase {
}
@GpuTestSuite
class XGBoostGpuRegressorSuite extends XGBoostRegressorSuiteBase {
override protected val treeMethod: String = "gpu_hist"
override protected val numWorkers: Int = 1
}

jvm-packages/xgboost4j/pom.xml

@ -43,6 +43,12 @@
<version>2.5.23</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.0.5</version>
<scope>compile</scope>
</dependency>
</dependencies>
<build>
@ -78,6 +84,8 @@
<executable>python</executable>
<arguments>
<argument>create_jni.py</argument>
<argument>--use-cuda</argument>
<argument>${use.cuda}</argument>
</arguments>
<workingDirectory>${user.dir}</workingDirectory>
</configuration>

jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/GpuTestSuite.java (new file)

@ -0,0 +1,28 @@
/*
Copyright (c) 2020 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import org.scalatest.TagAnnotation;
@TagAnnotation
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface GpuTestSuite {}

jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala

@ -46,6 +46,7 @@ object XGBoost {
} else {
prevBooster.booster
}
val xgboostInJava = checkpointParams.
map(cp => {
JXGBoost.trainAndSaveCheckpoint(

tests/ci_build/Dockerfile.gpu_jvm (new file)

@ -0,0 +1,51 @@
ARG CUDA_VERSION
FROM nvidia/cuda:$CUDA_VERSION-runtime-ubuntu16.04
ARG JDK_VERSION=8
ARG SPARK_VERSION=3.0.0
# Environment
ENV DEBIAN_FRONTEND noninteractive
# Install all basic requirements
RUN \
apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:openjdk-r/ppa && \
apt-get update && \
apt-get install -y tar unzip wget openjdk-$JDK_VERSION-jdk libgomp1 && \
# Python
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3.sh -b -p /opt/python && \
/opt/python/bin/pip install awscli && \
# Maven
wget https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \
tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \
ln -s /opt/apache-maven-3.6.1/ /opt/maven && \
# Spark
wget https://archive.apache.org/dist/spark/spark-$SPARK_VERSION/spark-$SPARK_VERSION-bin-hadoop2.7.tgz && \
tar xvf spark-$SPARK_VERSION-bin-hadoop2.7.tgz -C /opt && \
ln -s /opt/spark-$SPARK_VERSION-bin-hadoop2.7 /opt/spark
ENV PATH=/opt/python/bin:/opt/spark/bin:/opt/maven/bin:$PATH
# Install Python packages
RUN \
pip install numpy scipy pandas scikit-learn
ENV GOSU_VERSION 1.10
# Install lightweight sudo (not bound to TTY)
RUN set -ex; \
wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
chmod +x /usr/local/bin/gosu && \
gosu nobody true
# Set default JDK version
RUN update-java-alternatives -v -s java-1.$JDK_VERSION.0-openjdk-amd64
# Default entry-point to use if running locally
# It will preserve attributes of created files
COPY entrypoint.sh /scripts/
WORKDIR /workspace
ENTRYPOINT ["/scripts/entrypoint.sh"]

tests/ci_build/Dockerfile.jvm_gpu_build (new file)

@ -0,0 +1,63 @@
ARG CUDA_VERSION
FROM nvidia/cuda:$CUDA_VERSION-devel-centos6
ARG CUDA_VERSION
# Environment
ENV DEBIAN_FRONTEND noninteractive
ENV DEVTOOLSET_URL_ROOT http://vault.centos.org/6.9/sclo/x86_64/rh/devtoolset-4/
# Install all basic requirements
RUN \
yum -y update && \
yum install -y tar unzip wget xz git centos-release-scl yum-utils java-1.8.0-openjdk-devel && \
yum-config-manager --enable centos-sclo-rh-testing && \
yum -y update && \
yum install -y $DEVTOOLSET_URL_ROOT/devtoolset-4-gcc-5.3.1-6.1.el6.x86_64.rpm \
$DEVTOOLSET_URL_ROOT/devtoolset-4-gcc-c++-5.3.1-6.1.el6.x86_64.rpm \
$DEVTOOLSET_URL_ROOT/devtoolset-4-binutils-2.25.1-8.el6.x86_64.rpm \
$DEVTOOLSET_URL_ROOT/devtoolset-4-runtime-4.1-3.sc1.el6.x86_64.rpm \
$DEVTOOLSET_URL_ROOT/devtoolset-4-libstdc++-devel-5.3.1-6.1.el6.x86_64.rpm && \
# Python
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3.sh -b -p /opt/python && \
# CMake
wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr && \
# Maven
wget https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \
tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \
ln -s /opt/apache-maven-3.6.1/ /opt/maven
# NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html)
RUN \
export CUDA_SHORT=`echo $CUDA_VERSION | egrep -o '[0-9]+\.[0-9]'` && \
export NCCL_VERSION=2.4.8-1 && \
wget https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
rpm -i nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
yum -y update && \
yum install -y libnccl-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-devel-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-static-${NCCL_VERSION}+cuda${CUDA_SHORT} && \
rm -f nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm;
ENV PATH=/opt/python/bin:/opt/maven/bin:$PATH
ENV CC=/opt/rh/devtoolset-4/root/usr/bin/gcc
ENV CXX=/opt/rh/devtoolset-4/root/usr/bin/c++
ENV CPP=/opt/rh/devtoolset-4/root/usr/bin/cpp
# Install Python packages
RUN \
pip install numpy pytest scipy scikit-learn wheel kubernetes urllib3==1.22 awscli
ENV GOSU_VERSION 1.10
# Install lightweight sudo (not bound to TTY)
RUN set -ex; \
wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
chmod +x /usr/local/bin/gosu && \
gosu nobody true
# Default entry-point to use if running locally
# It will preserve attributes of created files
COPY entrypoint.sh /scripts/
WORKDIR /workspace
ENTRYPOINT ["/scripts/entrypoint.sh"]
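
For reference, a hedged sketch of building and using this image by hand, approximating what the ${dockerRun} wrapper in the Jenkinsfile does (the image tag is illustrative, and the Dockerfile path assumes the naming above):

    docker build -f tests/ci_build/Dockerfile.jvm_gpu_build \
      --build-arg CUDA_VERSION=10.0 -t xgboost/jvm-gpu-build tests/ci_build
    nvidia-docker run --rm -v "$PWD":/workspace xgboost/jvm-gpu-build \
      tests/ci_build/build_jvm_packages.sh 3.0.0 -Duse.cuda=ON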

tests/ci_build/build_jvm_packages.sh

@ -3,12 +3,15 @@
set -e
set -x
if [ $# -lt 1 ]; then
  echo "Usage: $0 <spark version> [-Duse.cuda=ON] [gpu arch flag]"
  exit 1
fi
spark_version=$1
use_cuda=$2
gpu_arch=$3
gpu_options=""
if [ "x$use_cuda" == "x-Duse.cuda=ON" ]; then
# The CPU build already runs the unit tests, so select the gpu-with-gpu-tests profile to run only the GPU suites
gpu_options=" -Pgpu-with-gpu-tests "
fi
# Initialize local Maven repository
./tests/ci_build/initialize_maven.sh
@ -16,7 +19,11 @@ spark_version=$1
rm -rf build/
cd jvm-packages
export RABIT_MOCK=ON
mvn --no-transfer-progress package -Dspark.version=${spark_version}
if [ "x$gpu_arch" != "x" ]; then
export GPU_ARCH_FLAG=$gpu_arch
fi
mvn --no-transfer-progress package -Dspark.version=${spark_version} $gpu_options
set +x
set +e
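
For clarity, the two ways the Jenkinsfile invokes this script inside the build container; the second form matches BuildJVMPackagesWithCUDA above, with the architecture flag optional:

    # CPU-only build; unit tests run under the default profile
    tests/ci_build/build_jvm_packages.sh 3.0.0
    # GPU build; runs only the GPU-tagged suites, optionally pinning the compute arch
    tests/ci_build/build_jvm_packages.sh 3.0.0 -Duse.cuda=ON -DGPU_COMPUTE_VER=75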

tests/ci_build/test_jvm_gpu_cross.sh (new file)

@ -0,0 +1,40 @@
#!/bin/bash
set -e
set -x
nvidia-smi
ls /usr/local/
# Initialize local Maven repository
./tests/ci_build/initialize_maven.sh
# Get version number of XGBoost4J and other auxiliary information
cd jvm-packages
xgboost4j_version=$(mvn help:evaluate -Dexpression=project.version -q -DforceStdout)
scala_binary_version=$(mvn help:evaluate -Dexpression=scala.binary.version -q -DforceStdout)
python3 xgboost4j-tester/get_iris.py
xgb_jars="./xgboost4j/target/xgboost4j_${scala_binary_version}-${xgboost4j_version}.jar,./xgboost4j-spark/target/xgboost4j-spark_${scala_binary_version}-${xgboost4j_version}.jar"
example_jar="./xgboost4j-example/target/xgboost4j-example_${scala_binary_version}-${xgboost4j_version}.jar"
echo "Run SparkTraining locally ... "
spark-submit \
--master 'local[1]' \
--class ml.dmlc.xgboost4j.scala.example.spark.SparkTraining \
--jars $xgb_jars \
$example_jar \
${PWD}/iris.csv gpu
echo "Run SparkMLlibPipeline locally ... "
spark-submit \
--master 'local[1]' \
--class ml.dmlc.xgboost4j.scala.example.spark.SparkMLlibPipeline \
--jars $xgb_jars \
$example_jar \
${PWD}/iris.csv ${PWD}/native_model ${PWD}/pipeline_model gpu
set +x
set +e