[jvm-packages] Add the new device parameter. (#9385)
This commit is contained in:
@@ -73,7 +73,7 @@ private[scala] case class XGBoostExecutionParams(
|
||||
xgbInputParams: XGBoostExecutionInputParams,
|
||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||
cacheTrainingSet: Boolean,
|
||||
treeMethod: Option[String],
|
||||
device: Option[String],
|
||||
isLocal: Boolean,
|
||||
featureNames: Option[Array[String]],
|
||||
featureTypes: Option[Array[String]]) {
|
||||
@@ -180,6 +180,10 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
||||
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
||||
}
|
||||
val device: Option[String] = overridedParams.get("device") match {
|
||||
case None => None
|
||||
case Some(dev: String) => if (treeMethod == "gpu_hist") Some("cuda") else Some(dev)
|
||||
}
|
||||
if (overridedParams.contains("train_test_ratio")) {
|
||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
||||
@@ -228,7 +232,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
inputParams,
|
||||
xgbExecEarlyStoppingParams,
|
||||
cacheTrainingSet,
|
||||
treeMethod,
|
||||
device,
|
||||
isLocal,
|
||||
featureNames,
|
||||
featureTypes
|
||||
@@ -318,7 +322,7 @@ object XGBoost extends Serializable {
|
||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||
|
||||
var params = xgbExecutionParam.toMap
|
||||
if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) {
|
||||
if (xgbExecutionParam.device.exists(m => (m == "cuda" || m == "gpu"))) {
|
||||
val gpuId = if (xgbExecutionParam.isLocal) {
|
||||
// For local mode, force gpu id to primary device
|
||||
0
|
||||
@@ -328,6 +332,7 @@ object XGBoost extends Serializable {
|
||||
logger.info("Leveraging gpu device " + gpuId + " to train")
|
||||
params = params + ("device" -> s"cuda:$gpuId")
|
||||
}
|
||||
|
||||
val booster = if (makeCheckpoint) {
|
||||
SXGBoost.trainAndSaveCheckpoint(
|
||||
watches.toMap("train"), params, numRounds,
|
||||
|
||||
@@ -93,6 +93,8 @@ class XGBoostClassifier (
|
||||
|
||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
||||
|
||||
def setDevice(value: String): this.type = set(device, value)
|
||||
|
||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
||||
|
||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
||||
|
||||
@@ -95,6 +95,8 @@ class XGBoostRegressor (
|
||||
|
||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
||||
|
||||
def setDevice(value: String): this.type = set(device, value)
|
||||
|
||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
||||
|
||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
||||
|
||||
@@ -154,6 +154,14 @@ private[spark] trait BoosterParams extends Params {
|
||||
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
||||
|
||||
final def getTreeMethod: String = $(treeMethod)
|
||||
/**
|
||||
* The device for running XGBoost algorithms, options: cpu, cuda
|
||||
*/
|
||||
final val device = new Param[String](
|
||||
this, "device", "The device for running XGBoost algorithms, options: cpu, cuda"
|
||||
)
|
||||
|
||||
final def getDevice: String = $(device)
|
||||
|
||||
/**
|
||||
* growth policy for fast histogram algorithm
|
||||
|
||||
@@ -284,7 +284,7 @@ private[spark] trait ParamMapFuncs extends Params {
|
||||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
||||
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
|
||||
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
|
||||
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker,prune or" +
|
||||
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker or" +
|
||||
s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
|
||||
}
|
||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||
|
||||
@@ -469,7 +469,6 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
|
||||
.setFeatureTypes(featureTypes)
|
||||
val model = xgb.fit(trainingDF)
|
||||
val modelStr = new String(model._booster.toByteArray("json"))
|
||||
System.out.println(modelStr)
|
||||
val jsonModel = parseJson(modelStr)
|
||||
implicit val formats: Formats = DefaultFormats
|
||||
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
|
||||
|
||||
Reference in New Issue
Block a user