[BLOCKING] [jvm-packages] add gpu_hist and enable gpu scheduling (#5171)
* [jvm-packages] add gpu_hist tree method
* change updater hist to grow_quantile_histmaker
* add gpu scheduling
* pass correct parameters to xgboost library
* remove debug info
* add use.cuda for pom
* add CI for gpu_hist for jvm
* add gpu unit tests
* use gpu node to build jvm
* use nvidia-docker
* Add CLI interface to create_jni.py using argparse

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -22,7 +22,6 @@ import java.nio.file.Files
|
||||
import scala.collection.{AbstractIterator, mutable}
|
||||
import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
@@ -32,7 +31,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.commons.io.FileUtils
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.hadoop.fs.FileSystem
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
@@ -76,7 +74,9 @@ private[this] case class XGBoostExecutionParams(
|
||||
checkpointParam: Option[ExternalCheckpointParams],
|
||||
xgbInputParams: XGBoostExecutionInputParams,
|
||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||
cacheTrainingSet: Boolean) {
|
||||
cacheTrainingSet: Boolean,
|
||||
treeMethod: Option[String],
|
||||
isLocal: Boolean) {
|
||||
|
||||
private var rawParamMap: Map[String, Any] = _
|
||||
|
||||
@@ -93,6 +93,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
private val isLocal = sc.isLocal
|
||||
|
||||
private val overridedParams = overrideParams(rawParams, sc)
|
||||
|
||||
/**
|
||||
@@ -168,11 +170,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
.getOrElse("allow_non_zero_for_missing", false)
|
||||
.asInstanceOf[Boolean]
|
||||
validateSparkSslConf
|
||||
var treeMethod: Option[String] = None
|
||||
if (overridedParams.contains("tree_method")) {
|
||||
require(overridedParams("tree_method") == "hist" ||
|
||||
overridedParams("tree_method") == "approx" ||
|
||||
overridedParams("tree_method") == "auto", "xgboost4j-spark only supports tree_method as" +
|
||||
" 'hist', 'approx' and 'auto'")
|
||||
overridedParams("tree_method") == "auto" ||
|
||||
overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" +
|
||||
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
||||
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
||||
}
|
||||
if (overridedParams.contains("train_test_ratio")) {
|
||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||
@@ -221,7 +226,9 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
checkpointParam,
|
||||
inputParams,
|
||||
xgbExecEarlyStoppingParams,
|
||||
cacheTrainingSet)
|
||||
cacheTrainingSet,
|
||||
treeMethod,
|
||||
isLocal)
|
||||
xgbExecParam.setRawParamMap(overridedParams)
|
||||
xgbExecParam
|
||||
}
|
||||
@@ -335,6 +342,26 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
private def getGPUAddrFromResources: Int = {
|
||||
val tc = TaskContext.get()
|
||||
if (tc == null) {
|
||||
throw new RuntimeException("Something wrong for task context")
|
||||
}
|
||||
val resources = tc.resources()
|
||||
if (resources.contains("gpu")) {
|
||||
val addrs = resources("gpu").addresses
|
||||
if (addrs.size > 1) {
|
||||
// TODO should we throw exception ?
|
||||
logger.warn("XGBoost only supports 1 gpu per worker")
|
||||
}
|
||||
// take the first one
|
||||
addrs.head.toInt
|
||||
} else {
|
||||
throw new RuntimeException("gpu is not allocated by spark, " +
|
||||
"please check if gpu scheduling is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
private def buildDistributedBooster(
|
||||
watches: Watches,
|
||||
xgbExecutionParam: XGBoostExecutionParams,
|
||||
@@ -362,13 +389,25 @@ object XGBoost extends Serializable {
|
||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||
|
||||
var params = xgbExecutionParam.toMap
|
||||
if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) {
|
||||
val gpuId = if (xgbExecutionParam.isLocal) {
|
||||
// For local mode, force gpu id to primary device
|
||||
0
|
||||
} else {
|
||||
getGPUAddrFromResources
|
||||
}
|
||||
logger.info("Leveraging gpu device " + gpuId + " to train")
|
||||
params = params + ("gpu_id" -> gpuId)
|
||||
}
|
||||
val booster = if (makeCheckpoint) {
|
||||
SXGBoost.trainAndSaveCheckpoint(
|
||||
watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
|
||||
watches.toMap("train"), params, numRounds,
|
||||
watches.toMap, metrics, obj, eval,
|
||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
|
||||
} else {
|
||||
SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
|
||||
SXGBoost.train(watches.toMap("train"), params, numRounds,
|
||||
watches.toMap, metrics, obj, eval,
|
||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
||||
}
|
||||
|
||||
@@ -145,11 +145,12 @@ private[spark] trait BoosterParams extends Params {
|
||||
final def getAlpha: Double = $(alpha)
|
||||
|
||||
/**
|
||||
* The tree construction algorithm used in XGBoost. options: {'auto', 'exact', 'approx'}
|
||||
* [default='auto']
|
||||
* The tree construction algorithm used in XGBoost. options:
|
||||
* {'auto', 'exact', 'approx','gpu_hist'} [default='auto']
|
||||
*/
|
||||
final val treeMethod = new Param[String](this, "treeMethod",
|
||||
"The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}",
|
||||
"The tree construction algorithm used in XGBoost, options: " +
|
||||
"{'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
|
||||
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
||||
|
||||
final def getTreeMethod: String = $(treeMethod)
|
||||
@@ -292,7 +293,7 @@ private[spark] object BoosterParams {
|
||||
|
||||
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
|
||||
|
||||
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist")
|
||||
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")
|
||||
|
||||
val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
|
||||
|
||||
|
||||
@@ -261,10 +261,10 @@ private[spark] trait ParamMapFuncs extends Params {
|
||||
for ((paramName, paramValue) <- xgboostParams) {
|
||||
if ((paramName == "booster" && paramValue != "gbtree") ||
|
||||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
||||
paramValue != "hist")) {
|
||||
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
|
||||
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
|
||||
s" XGBoost-Spark only supports gbtree as booster type" +
|
||||
" and grow_histmaker,prune or hist as the updater type")
|
||||
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker,prune or" +
|
||||
s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
|
||||
}
|
||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||
params.find(_.name == name).foreach {
|
||||
|
||||
Reference in New Issue
Block a user