[jvm-packages] disable booster setup for xgboost4j-spark (#3456)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* disable booster setup in spark

* check in parameter conversion

* fix compilation issue

* update exception type
This commit is contained in:
Nan Zhu 2018-07-07 21:57:24 -07:00 committed by GitHub
parent 66e74d2223
commit aa90e5c6ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 16 deletions

View File

@ -82,9 +82,6 @@ class XGBoostClassifier (
def setSeed(value: Long): this.type = set(seed, value)
// setters for booster params
def setBooster(value: String): this.type = set(booster, value)
def setEta(value: Double): this.type = set(eta, value)
def setGamma(value: Double): this.type = set(gamma, value)

View File

@ -84,9 +84,6 @@ class XGBoostRegressor (
def setSeed(value: Long): this.type = set(seed, value)
// setters for booster params
def setBooster(value: String): this.type = set(booster, value)
def setEta(value: Double): this.type = set(eta, value)
def setGamma(value: Double): this.type = set(gamma, value)

View File

@ -22,15 +22,6 @@ import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
private[spark] trait BoosterParams extends Params {
/**
* Booster to use, options: {'gbtree', 'gblinear', 'dart'}
*/
final val booster = new Param[String](this, "booster",
s"Booster to use, options: {'gbtree', 'gblinear', 'dart'}",
(value: String) => BoosterParams.supportedBoosters.contains(value.toLowerCase))
final def getBooster: String = $(booster)
/**
* step size shrinkage used in update to prevent overfitting. After each boosting step, we
* can directly get the weights of new features and eta actually shrinks the feature weights
@ -246,7 +237,7 @@ private[spark] trait BoosterParams extends Params {
final def getLambdaBias: Double = $(lambdaBias)
setDefault(booster -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
minChildWeight -> 1, maxDeltaStep -> 0,
growPolicy -> "depthwise", maxBins -> 16,
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,

View File

@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params
import com.google.common.base.CaseFormat
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.apache.spark.ml.param._
import scala.collection.mutable
@ -198,6 +199,12 @@ private[spark] trait ParamMapFuncs extends Params {
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
for ((paramName, paramValue) <- xgboostParams) {
if ((paramName == "booster" && paramValue != "gbtree") ||
(paramName == "updater" && paramValue != "grow_colmaker,prune")) {
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
s" XGBoost-Spark only supports gbtree as booster type" +
" and grow_colmaker,prune as the updater type")
}
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
params.find(_.name == name) match {
case None =>