[jvm-packages] bring back camel case variants of parameters (#10845)
This commit is contained in:
parent
2179baa50c
commit
a049490cdb
@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import com.google.common.base.CaseFormat
|
||||
import org.apache.spark.ml.param._
|
||||
|
||||
private[spark] trait ParamMapConversion extends NonXGBoostParams {
|
||||
@ -28,20 +29,23 @@ private[spark] trait ParamMapConversion extends NonXGBoostParams {
|
||||
* @param xgboostParams XGBoost style parameters
|
||||
*/
|
||||
def xgboost2SparkParams(xgboostParams: Map[String, Any]): Unit = {
|
||||
for ((name, paramValue) <- xgboostParams) {
|
||||
params.find(_.name == name).foreach {
|
||||
case _: DoubleParam =>
|
||||
set(name, paramValue.toString.toDouble)
|
||||
case _: BooleanParam =>
|
||||
set(name, paramValue.toString.toBoolean)
|
||||
case _: IntParam =>
|
||||
set(name, paramValue.toString.toInt)
|
||||
case _: FloatParam =>
|
||||
set(name, paramValue.toString.toFloat)
|
||||
case _: LongParam =>
|
||||
set(name, paramValue.toString.toLong)
|
||||
case _: Param[_] =>
|
||||
set(name, paramValue)
|
||||
for ((paramName, paramValue) <- xgboostParams) {
|
||||
val lowerCamelName = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||
val lowerName = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, paramName)
|
||||
val qualifiedNames = mutable.Set(paramName, lowerName, lowerCamelName)
|
||||
params.find(p => qualifiedNames.contains(p.name)) foreach {
|
||||
case p: DoubleParam =>
|
||||
set(p.name, paramValue.toString.toDouble)
|
||||
case p: BooleanParam =>
|
||||
set(p.name, paramValue.toString.toBoolean)
|
||||
case p: IntParam =>
|
||||
set(p.name, paramValue.toString.toInt)
|
||||
case p: FloatParam =>
|
||||
set(p.name, paramValue.toString.toFloat)
|
||||
case p: LongParam =>
|
||||
set(p.name, paramValue.toString.toLong)
|
||||
case p: Param[_] =>
|
||||
set(p.name, paramValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -49,7 +53,7 @@ private[spark] trait ParamMapConversion extends NonXGBoostParams {
|
||||
/**
|
||||
* Convert the user-supplied parameters to the XGBoost parameters.
|
||||
*
|
||||
* Note that this also contains jvm-specific parameters.
|
||||
* Note that this doesn't contain jvm-specific parameters.
|
||||
*/
|
||||
def getXGBoostParams: Map[String, Any] = {
|
||||
val xgboostParams = new mutable.HashMap[String, Any]()
|
||||
|
||||
@ -69,6 +69,48 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
|
||||
assert(model.getContribPredictionCol === "contrib")
|
||||
}
|
||||
|
||||
test("camel case parameters") {
|
||||
val xgbParams: Map[String, Any] = Map(
|
||||
"max_depth" -> 5,
|
||||
"featuresCol" -> "abc",
|
||||
"num_workers" -> 2,
|
||||
"numRound" -> 11
|
||||
)
|
||||
val estimator = new XGBoostClassifier(xgbParams)
|
||||
assert(estimator.getFeaturesCol === "abc")
|
||||
assert(estimator.getNumWorkers === 2)
|
||||
assert(estimator.getNumRound === 11)
|
||||
assert(estimator.getMaxDepth === 5)
|
||||
|
||||
val xgbParams1: Map[String, Any] = Map(
|
||||
"maxDepth" -> 5,
|
||||
"features_col" -> "abc",
|
||||
"numWorkers" -> 2,
|
||||
"num_round" -> 11
|
||||
)
|
||||
val estimator1 = new XGBoostClassifier(xgbParams1)
|
||||
assert(estimator1.getFeaturesCol === "abc")
|
||||
assert(estimator1.getNumWorkers === 2)
|
||||
assert(estimator1.getNumRound === 11)
|
||||
assert(estimator1.getMaxDepth === 5)
|
||||
}
|
||||
|
||||
test("get xgboost parameters") {
|
||||
val params: Map[String, Any] = Map(
|
||||
"max_depth" -> 5,
|
||||
"featuresCol" -> "abc",
|
||||
"label" -> "class",
|
||||
"num_workers" -> 2,
|
||||
"tree_method" -> "hist",
|
||||
"numRound" -> 11,
|
||||
"not_exist_parameters" -> "hello"
|
||||
)
|
||||
val estimator = new XGBoostClassifier(params)
|
||||
val xgbParams = estimator.getXGBoostParams
|
||||
assert(xgbParams.size === 2)
|
||||
assert(xgbParams.contains("max_depth") && xgbParams.contains("tree_method"))
|
||||
}
|
||||
|
||||
test("nthread") {
|
||||
val classifier = new XGBoostClassifier().setNthread(100)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user