[jvm-packages] Add the new device parameter. (#9385)

This commit is contained in:
Jiaming Yuan
2023-07-17 18:40:39 +08:00
committed by GitHub
parent 2caceb157d
commit f4fb2be101
15 changed files with 112 additions and 47 deletions

View File

@@ -73,7 +73,7 @@ private[scala] case class XGBoostExecutionParams(
xgbInputParams: XGBoostExecutionInputParams,
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean,
treeMethod: Option[String],
device: Option[String],
isLocal: Boolean,
featureNames: Option[Array[String]],
featureTypes: Option[Array[String]]) {
@@ -180,6 +180,10 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
}
// Resolve the optional "device" parameter supplied by the user.
// `treeMethod` is an Option[String], so it must be inspected with `exists`: comparing the
// Option itself against the string "gpu_hist" is never true, which would silently skip the
// override below.
val device: Option[String] = overridedParams.get("device") match {
  case None => None
  case Some(dev: String) =>
    // When the tree method is "gpu_hist", the device is forced to "cuda" regardless of `dev`.
    if (treeMethod.exists(_ == "gpu_hist")) Some("cuda") else Some(dev)
}
if (overridedParams.contains("train_test_ratio")) {
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
@@ -228,7 +232,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
inputParams,
xgbExecEarlyStoppingParams,
cacheTrainingSet,
treeMethod,
device,
isLocal,
featureNames,
featureTypes
@@ -318,7 +322,7 @@ object XGBoost extends Serializable {
val externalCheckpointParams = xgbExecutionParam.checkpointParam
var params = xgbExecutionParam.toMap
if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) {
if (xgbExecutionParam.device.exists(m => (m == "cuda" || m == "gpu"))) {
val gpuId = if (xgbExecutionParam.isLocal) {
// For local mode, force gpu id to primary device
0
@@ -328,6 +332,7 @@ object XGBoost extends Serializable {
logger.info("Leveraging gpu device " + gpuId + " to train")
params = params + ("device" -> s"cuda:$gpuId")
}
val booster = if (makeCheckpoint) {
SXGBoost.trainAndSaveCheckpoint(
watches.toMap("train"), params, numRounds,

View File

@@ -93,6 +93,8 @@ class XGBoostClassifier (
/** Sets the `treeMethod` param to `value`; the return type is `this.type`. */
def setTreeMethod(value: String): this.type = set(treeMethod, value)
/** Sets the `device` param to `value`; the return type is `this.type`. */
def setDevice(value: String): this.type = set(device, value)
/** Sets the `growPolicy` param to `value`; the return type is `this.type`. */
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
/** Sets the `maxBins` param to `value`; the return type is `this.type`. */
def setMaxBins(value: Int): this.type = set(maxBins, value)

View File

@@ -95,6 +95,8 @@ class XGBoostRegressor (
/** Sets the `treeMethod` param to `value`; the return type is `this.type`. */
def setTreeMethod(value: String): this.type = set(treeMethod, value)
/** Sets the `device` param to `value`; the return type is `this.type`. */
def setDevice(value: String): this.type = set(device, value)
/** Sets the `growPolicy` param to `value`; the return type is `this.type`. */
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
/** Sets the `maxBins` param to `value`; the return type is `this.type`. */
def setMaxBins(value: Int): this.type = set(maxBins, value)

View File

@@ -154,6 +154,14 @@ private[spark] trait BoosterParams extends Params {
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
/** Returns the value of the `treeMethod` param, looked up through `$`. */
final def getTreeMethod: String = $(treeMethod)
/**
* The device for running XGBoost algorithms, options: cpu, cuda
*
* The param is constructed from only its parent, its name ("device") and a doc string; no
* validation function is passed here, unlike params that supply an explicit validity check.
* NOTE(review): the listed options are therefore documentation only at this point; confirm
* whether other values (e.g. "gpu" or an ordinal-suffixed form) are meant to be accepted.
*/
final val device = new Param[String](
this, "device", "The device for running XGBoost algorithms, options: cpu, cuda"
)
/** Returns the value of the `device` param, looked up through `$`. */
final def getDevice: String = $(device)
/**
* growth policy for fast histogram algorithm

View File

@@ -284,7 +284,7 @@ private[spark] trait ParamMapFuncs extends Params {
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker,prune or" +
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker or" +
s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
}
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)

View File

@@ -469,7 +469,6 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
.setFeatureTypes(featureTypes)
val model = xgb.fit(trainingDF)
val modelStr = new String(model._booster.toByteArray("json"))
System.out.println(modelStr)
val jsonModel = parseJson(modelStr)
implicit val formats: Formats = DefaultFormats
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]