[jvm-packages] disable booster setup for xgboost4j-spark (#3456)
* add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * disable booster setup in spark * check in parameter conversion * fix compilation issue * update exception type
This commit is contained in:
parent
66e74d2223
commit
aa90e5c6ce
@ -82,9 +82,6 @@ class XGBoostClassifier (
|
||||
|
||||
def setSeed(value: Long): this.type = set(seed, value)
|
||||
|
||||
// setters for booster params
|
||||
def setBooster(value: String): this.type = set(booster, value)
|
||||
|
||||
def setEta(value: Double): this.type = set(eta, value)
|
||||
|
||||
def setGamma(value: Double): this.type = set(gamma, value)
|
||||
|
||||
@ -84,9 +84,6 @@ class XGBoostRegressor (
|
||||
|
||||
def setSeed(value: Long): this.type = set(seed, value)
|
||||
|
||||
// setters for booster params
|
||||
def setBooster(value: String): this.type = set(booster, value)
|
||||
|
||||
def setEta(value: Double): this.type = set(eta, value)
|
||||
|
||||
def setGamma(value: Double): this.type = set(gamma, value)
|
||||
|
||||
@ -22,15 +22,6 @@ import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
|
||||
|
||||
private[spark] trait BoosterParams extends Params {
|
||||
|
||||
/**
|
||||
* Booster to use, options: {'gbtree', 'gblinear', 'dart'}
|
||||
*/
|
||||
final val booster = new Param[String](this, "booster",
|
||||
s"Booster to use, options: {'gbtree', 'gblinear', 'dart'}",
|
||||
(value: String) => BoosterParams.supportedBoosters.contains(value.toLowerCase))
|
||||
|
||||
final def getBooster: String = $(booster)
|
||||
|
||||
/**
|
||||
* step size shrinkage used in update to prevents overfitting. After each boosting step, we
|
||||
* can directly get the weights of new features and eta actually shrinks the feature weights
|
||||
@ -246,7 +237,7 @@ private[spark] trait BoosterParams extends Params {
|
||||
|
||||
final def getLambdaBias: Double = $(lambdaBias)
|
||||
|
||||
setDefault(booster -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
||||
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
||||
minChildWeight -> 1, maxDeltaStep -> 0,
|
||||
growPolicy -> "depthwise", maxBins -> 16,
|
||||
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
|
||||
|
||||
@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import com.google.common.base.CaseFormat
|
||||
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
||||
|
||||
import org.apache.spark.ml.param._
|
||||
import scala.collection.mutable
|
||||
|
||||
@ -198,6 +199,12 @@ private[spark] trait ParamMapFuncs extends Params {
|
||||
|
||||
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
|
||||
for ((paramName, paramValue) <- xgboostParams) {
|
||||
if ((paramName == "booster" && paramValue != "gbtree") ||
|
||||
(paramName == "updater" && paramValue != "grow_colmaker,prune")) {
|
||||
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
|
||||
s" XGBoost-Spark only supports gbtree as booster type" +
|
||||
" and grow_colmaker,prune as the updater type")
|
||||
}
|
||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||
params.find(_.name == name) match {
|
||||
case None =>
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user