[jvm-packages] bring back camel case variants of parameters (#10845)
This commit is contained in:
parent
2179baa50c
commit
a049490cdb
@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
|||||||
|
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import com.google.common.base.CaseFormat
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
|
|
||||||
private[spark] trait ParamMapConversion extends NonXGBoostParams {
|
private[spark] trait ParamMapConversion extends NonXGBoostParams {
|
||||||
@ -28,20 +29,23 @@ private[spark] trait ParamMapConversion extends NonXGBoostParams {
|
|||||||
* @param xgboostParams XGBoost style parameters
|
* @param xgboostParams XGBoost style parameters
|
||||||
*/
|
*/
|
||||||
def xgboost2SparkParams(xgboostParams: Map[String, Any]): Unit = {
|
def xgboost2SparkParams(xgboostParams: Map[String, Any]): Unit = {
|
||||||
for ((name, paramValue) <- xgboostParams) {
|
for ((paramName, paramValue) <- xgboostParams) {
|
||||||
params.find(_.name == name).foreach {
|
val lowerCamelName = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||||
case _: DoubleParam =>
|
val lowerName = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, paramName)
|
||||||
set(name, paramValue.toString.toDouble)
|
val qualifiedNames = mutable.Set(paramName, lowerName, lowerCamelName)
|
||||||
case _: BooleanParam =>
|
params.find(p => qualifiedNames.contains(p.name)) foreach {
|
||||||
set(name, paramValue.toString.toBoolean)
|
case p: DoubleParam =>
|
||||||
case _: IntParam =>
|
set(p.name, paramValue.toString.toDouble)
|
||||||
set(name, paramValue.toString.toInt)
|
case p: BooleanParam =>
|
||||||
case _: FloatParam =>
|
set(p.name, paramValue.toString.toBoolean)
|
||||||
set(name, paramValue.toString.toFloat)
|
case p: IntParam =>
|
||||||
case _: LongParam =>
|
set(p.name, paramValue.toString.toInt)
|
||||||
set(name, paramValue.toString.toLong)
|
case p: FloatParam =>
|
||||||
case _: Param[_] =>
|
set(p.name, paramValue.toString.toFloat)
|
||||||
set(name, paramValue)
|
case p: LongParam =>
|
||||||
|
set(p.name, paramValue.toString.toLong)
|
||||||
|
case p: Param[_] =>
|
||||||
|
set(p.name, paramValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -49,7 +53,7 @@ private[spark] trait ParamMapConversion extends NonXGBoostParams {
|
|||||||
/**
|
/**
|
||||||
* Convert the user-supplied parameters to the XGBoost parameters.
|
* Convert the user-supplied parameters to the XGBoost parameters.
|
||||||
*
|
*
|
||||||
* Note that this also contains jvm-specific parameters.
|
* Note that this doesn't contain jvm-specific parameters.
|
||||||
*/
|
*/
|
||||||
def getXGBoostParams: Map[String, Any] = {
|
def getXGBoostParams: Map[String, Any] = {
|
||||||
val xgboostParams = new mutable.HashMap[String, Any]()
|
val xgboostParams = new mutable.HashMap[String, Any]()
|
||||||
|
|||||||
@ -69,6 +69,48 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
|
|||||||
assert(model.getContribPredictionCol === "contrib")
|
assert(model.getContribPredictionCol === "contrib")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("camel case parameters") {
|
||||||
|
val xgbParams: Map[String, Any] = Map(
|
||||||
|
"max_depth" -> 5,
|
||||||
|
"featuresCol" -> "abc",
|
||||||
|
"num_workers" -> 2,
|
||||||
|
"numRound" -> 11
|
||||||
|
)
|
||||||
|
val estimator = new XGBoostClassifier(xgbParams)
|
||||||
|
assert(estimator.getFeaturesCol === "abc")
|
||||||
|
assert(estimator.getNumWorkers === 2)
|
||||||
|
assert(estimator.getNumRound === 11)
|
||||||
|
assert(estimator.getMaxDepth === 5)
|
||||||
|
|
||||||
|
val xgbParams1: Map[String, Any] = Map(
|
||||||
|
"maxDepth" -> 5,
|
||||||
|
"features_col" -> "abc",
|
||||||
|
"numWorkers" -> 2,
|
||||||
|
"num_round" -> 11
|
||||||
|
)
|
||||||
|
val estimator1 = new XGBoostClassifier(xgbParams1)
|
||||||
|
assert(estimator1.getFeaturesCol === "abc")
|
||||||
|
assert(estimator1.getNumWorkers === 2)
|
||||||
|
assert(estimator1.getNumRound === 11)
|
||||||
|
assert(estimator1.getMaxDepth === 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("get xgboost parameters") {
|
||||||
|
val params: Map[String, Any] = Map(
|
||||||
|
"max_depth" -> 5,
|
||||||
|
"featuresCol" -> "abc",
|
||||||
|
"label" -> "class",
|
||||||
|
"num_workers" -> 2,
|
||||||
|
"tree_method" -> "hist",
|
||||||
|
"numRound" -> 11,
|
||||||
|
"not_exist_parameters" -> "hello"
|
||||||
|
)
|
||||||
|
val estimator = new XGBoostClassifier(params)
|
||||||
|
val xgbParams = estimator.getXGBoostParams
|
||||||
|
assert(xgbParams.size === 2)
|
||||||
|
assert(xgbParams.contains("max_depth") && xgbParams.contains("tree_method"))
|
||||||
|
}
|
||||||
|
|
||||||
test("nthread") {
|
test("nthread") {
|
||||||
val classifier = new XGBoostClassifier().setNthread(100)
|
val classifier = new XGBoostClassifier().setNthread(100)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user