diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index dfcedd3ed..c2e53e9a4 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -82,9 +82,6 @@ class XGBoostClassifier ( def setSeed(value: Long): this.type = set(seed, value) - // setters for booster params - def setBooster(value: String): this.type = set(booster, value) - def setEta(value: Double): this.type = set(eta, value) def setGamma(value: Double): this.type = set(gamma, value) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index dcc6e534a..93c4b7446 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -84,9 +84,6 @@ class XGBoostRegressor ( def setSeed(value: Long): this.type = set(seed, value) - // setters for booster params - def setBooster(value: String): this.type = set(booster, value) - def setEta(value: Double): this.type = set(eta, value) def setGamma(value: Double): this.type = set(gamma, value) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala index 0bc54fd6f..c3be3601b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala @@ -22,15 +22,6 @@ import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params} private[spark] trait BoosterParams extends Params { - /** - * Booster to use, options: {'gbtree', 'gblinear', 'dart'} - */ - final val booster = new Param[String](this, "booster", - s"Booster to use, options: {'gbtree', 'gblinear', 'dart'}", - (value: String) => BoosterParams.supportedBoosters.contains(value.toLowerCase)) - - final def getBooster: String = $(booster) - /** * step size shrinkage used in update to prevents overfitting. After each boosting step, we * can directly get the weights of new features and eta actually shrinks the feature weights @@ -246,7 +237,7 @@ private[spark] trait BoosterParams extends Params { final def getLambdaBias: Double = $(lambdaBias) - setDefault(booster -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6, + setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6, minChildWeight -> 1, maxDeltaStep -> 0, growPolicy -> "depthwise", maxBins -> 16, subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1, diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index 4dc9e7a39..6a27195b8 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params import com.google.common.base.CaseFormat import ml.dmlc.xgboost4j.scala.spark.TrackerConf + import org.apache.spark.ml.param._ import scala.collection.mutable @@ -198,6 +199,12 @@ private[spark] trait ParamMapFuncs extends Params { def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = { for ((paramName, paramValue) <- xgboostParams) { + if ((paramName == "booster" && paramValue != "gbtree") || + (paramName == "updater" && paramValue != "grow_colmaker,prune")) { + throw new IllegalArgumentException(s"you specified $paramName as $paramValue," + + s" XGBoost-Spark only supports gbtree as booster type" + + " and grow_colmaker,prune as the updater type") + } val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName) params.find(_.name == name) match { case None =>